pytorch学习五、深度学习计算

来自于 https://tangshusen.me/Dive-into-DL-PyTorch/#/

官方文档 https://pytorch.org/docs/stable/tensors.html

模型构建

Module类是nn模块里提供的一个模型构造类,是所有神经网络模块的基类,我们可以继承它来定义我们想要的模型。下面继承Module类构造本节开头提到的多层感知机。这里定义的MLP类重载了Module类的__init__函数和forward函数。它们分别用于创建模型参数和定义前向计算。前向计算也即正向传播。

import torch
from torch import nn

class MLP(nn.Module):
    # 声明带有模型参数的层,这里声明了两个全连接层
    def __init__(self,**kwargs):
        # 调用MLP父类Module的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数
        # 参数,如“模型参数的访问、初始化和共享”一节将介绍的模型参数params
        super(MLP,self).__init__(**kwargs)
        self.hidden = nn.Linear(784,256)
        self.act = nn.ReLU()
        self.output = nn.Linear(256,10)
    
    # 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出
    def forward(self,x):
        a = self.act(self.hidden(x))
        return self.output(a)
    

实例化MLP

x = torch.rand(2,784)
net = MLP()
print(net)
net(x)
MLP(
  (hidden): Linear(in_features=784, out_features=256, bias=True)
  (act): ReLU()
  (output): Linear(in_features=256, out_features=10, bias=True)
)





tensor([[-0.0370, -0.0794,  0.1573, -0.0460, -0.0318, -0.0653, -0.2214,  0.0340,
          0.0324,  0.1943],
        [-0.1231,  0.0116,  0.2601, -0.0035,  0.0520, -0.0406, -0.0722,  0.0060,
         -0.1080,  0.3512]], grad_fn=<AddmmBackward>)

Module的子类

Sequential

当模型的前向计算为简单串联各个层的计算时,Sequential类可以通过更加简单的方式定义模型。这正是Sequential类的目的:它可以接收一个子模块的有序字典(OrderedDict)或者一系列子模块作为参数来逐一添加Module的实例,而模型的前向计算就是将这些实例按添加的顺序逐一计算。

下面我们实现一个与Sequential类有相同功能的MySequential类。这或许可以帮助读者更加清晰地理解Sequential类的工作机制。

class MySequential(nn.Module):
    from collections import OrderedDict
    def __init__(self,*args):
        super(MySequential,self).__init__()
        if len(args) ==1 and isinstance(args[0],OrderedDict):
            for key,module in args[0].item():
                self.add_module(key,module)
        else:
            for idx ,module in enumerate(args):
                self.add_module(str(idx),module)
    def forward(self,input):
        
        for module in self._modules.values():
            input=module(input)
        return input

net = MySequential(
    nn.Linear(784,256),
    nn.ReLU(),
    nn.Linear(256,10)
)
print(net)
net(x)
MySequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)





tensor([[ 0.1224, -0.0237, -0.0957,  0.0353, -0.0509, -0.1116, -0.0750, -0.0186,
         -0.0620,  0.1194],
        [ 0.1232, -0.0997, -0.1981, -0.0289,  0.0448,  0.0705, -0.1342,  0.0394,
         -0.1696,  0.1937]], grad_fn=<AddmmBackward>)

ModuleList

ModuleList接收一个子模块的列表作为输入,然后也可以类似List那样进行append和extend操作:

net = nn.ModuleList([nn.Linear(784,256),nn.ReLU()])
net.append(nn.Linear(256,10))
print(net[-1])
print(net)
Linear(in_features=256, out_features=10, bias=True)
ModuleList(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

既然Sequential和ModuleList都可以进行列表化构造网络,那二者区别是什么呢。ModuleList仅仅是一个储存各种模块的列表,这些模块之间没有联系也没有顺序(所以不用保证相邻层的输入输出维度匹配),而且没有实现forward功能需要自己实现,所以上面执行net(torch.zeros(1, 784))会报NotImplementedError;而Sequential内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部forward功能已经实现。

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

另外,ModuleList不同于一般的Python的list,加入到ModuleList里面的所有模块的参数会被自动添加到整个网络中,下面看一个例子对比一下。

class Module_ModuleList(nn.Module):
    def __init__(self):
        super(Module_ModuleList,self).__init__()
        self.lineears = nn.ModuleList([nn.Linear(10,10)])

class Module_List(nn.Module):
    def __init__(self):
        super(Module_List,self).__init__()
        self.linears  = [nn.Linear(10,10)]

net1 = Module_ModuleList()
net2 = Module_List()
print('net1:')
for p in net1.parameters():
    print(p.size())
print('net2:')
for p in net2.parameters():

    print(p.size())
net1:
torch.Size([10, 10])
torch.Size([10])
net2:

ModuleDict

ModuleDict接收一个子模块的字典作为输入, 然后也可以类似字典那样进行添加访问操作:

net = nn.ModuleDict({
    'linear': nn.Linear(784, 256),
    'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问
print(net.output)
print(net)
Linear(in_features=784, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
ModuleDict(
  (act): ReLU()
  (linear): Linear(in_features=784, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)

和ModuleList一样,ModuleDict实例仅仅是存放了一些模块的字典,并没有定义forward函数需要自己定义。同样,ModuleDict也与Python的Dict有所不同,ModuleDict里的所有模块的参数会被自动添加到整个网络中。

构造复杂的模型

虽然上面介绍的这些类可以使模型构造更加简单,且不需要定义forward函数,但直接继承Module类可以极大地拓展模型构造的灵活性。下面我们构造一个稍微复杂点的网络FancyMLP。在这个网络中,我们通过get_constant函数创建训练中不被迭代的参数,即常数参数。在前向计算中,除了使用创建的常数参数外,我们还使用Tensor的函数和Python的控制流,并多次调用相同的层。

class FancyMLP(nn.Module):
    def __init__(self,**kwargs):
        super(FancyMLP,self).__init__(**kwargs)
        self.rand_weight = torch.rand((20,20),requires_grad=False)
        self.linear = nn.Linear(20,20)
    def forward(self,x):
        x = self.linear(x)
        
        x=nn.functional.relu(torch.mm(x,self.rand_weight.data)+1)
        
        
        # 复用全连接层。等价于两个全连接层共享参数
        x=self.linear(x)
        
        while x.norm().item()>1:
            x/=2
        if x.norm().item() <0.8:
            x*=10
        return x.sum()

X = torch.rand(2, 20)
net = FancyMLP()
print(net)
net(X)
FancyMLP(
  (linear): Linear(in_features=20, out_features=20, bias=True)
)





tensor(-0.4913, grad_fn=<SumBackward0>)

因为FancyMLP和Sequential类都是Module类的子类,所以我们可以嵌套调用它们。

class NestMLP(nn.Module):
    def __init__(self, **kwargs):
        super(NestMLP, self).__init__(**kwargs)
        self.net = nn.Sequential(nn.Linear(40, 30), nn.ReLU()) 

    def forward(self, x):
        return self.net(x)

net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())

X = torch.rand(2, 40)
print(net)
net(X)
 
Sequential(
  (0): NestMLP(
    (net): Sequential(
      (0): Linear(in_features=40, out_features=30, bias=True)
      (1): ReLU()
    )
  )
  (1): Linear(in_features=30, out_features=20, bias=True)
  (2): FancyMLP(
    (linear): Linear(in_features=20, out_features=20, bias=True)
  )
)





tensor(0.3448, grad_fn=<SumBackward0>)

访问模型参数

print(type(net.named_parameters()))
for name, param in net.named_parameters():
    print(name, param.size())
<class 'generator'>
0.net.0.weight torch.Size([30, 40])
0.net.0.bias torch.Size([30])
1.weight torch.Size([20, 30])
1.bias torch.Size([20])
2.linear.weight torch.Size([20, 20])
2.linear.bias torch.Size([20])

可见返回的名字自动加上了层数的索引作为前缀。 我们再来访问net中单层的参数。对于使用Sequential类构造的神经网络,我们可以通过方括号[]来访问网络的任一层。索引0表示隐藏层为Sequential实例最先添加的层。

for name, param in net[0].named_parameters():
    print(name, param.size(), type(param))
net.0.weight torch.Size([30, 40]) <class 'torch.nn.parameter.Parameter'>
net.0.bias torch.Size([30]) <class 'torch.nn.parameter.Parameter'>

返回的param的类型为torch.nn.parameter.Parameter,其实这是Tensor的子类,和Tensor不同的是如果一个Tensor是Parameter,那么它会自动被添加到模型的参数列表里,来看下面这个例子。

class MyModel(nn.Module):
    def __init__(self,**kwargs):
        super(MyModel,self).__init__(**kwargs)
        self.weight1 = nn.Parameter(torch.rand(20,20))
        self.weight2 = torch.rand(20,20)
    def forward(self,x):
        pass
n = MyModel()
for name,param in n.named_parameters():
    print(name)
weight1

因为Parameter是Tensor,即Tensor拥有的属性它都有,比如可以根据data来访问参数数值,用grad来访问参数梯度。

weight_0 = list(net[0].parameters())[0]
print(weight_0.data)
print(weight_0.grad)
Y=net(X)
Y.backward()
print(weight_0.grad)

tensor([[ 0.0622,  0.0654,  0.1042,  ..., -0.1514,  0.0846,  0.1257],
        [ 0.0139,  0.0422, -0.1171,  ...,  0.0142, -0.0565, -0.1016],
        [ 0.0405,  0.1393,  0.0782,  ...,  0.0151, -0.0972, -0.1105],
        ...,
        [ 0.1332,  0.0998, -0.0605,  ..., -0.0061,  0.0571, -0.1063],
        [-0.1366, -0.1384,  0.0207,  ..., -0.0695,  0.0810, -0.0280],
        [ 0.0659,  0.1505,  0.1109,  ..., -0.0085, -0.0435, -0.0354]])
None
tensor([[-0.0022, -0.0218, -0.0053,  ..., -0.0030, -0.0048, -0.0103],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0864,  0.3693,  0.1030,  ...,  0.1747,  0.1237,  0.4375],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0101, -0.0433, -0.0121,  ..., -0.0205, -0.0145, -0.0514]])

初始化模型参数

from torch.nn import init
for name, param in net.named_parameters():
    if 'weight' in name:
        init.normal_(param, mean=0, std=0.01)
        print(name, param.data)
0.net.0.weight tensor([[ 0.0216, -0.0093,  0.0016,  ...,  0.0019,  0.0130, -0.0026],
        [ 0.0042, -0.0017,  0.0056,  ..., -0.0125, -0.0168, -0.0058],
        [-0.0046, -0.0081,  0.0077,  ...,  0.0097, -0.0123, -0.0045],
        ...,
        [-0.0062, -0.0014, -0.0027,  ...,  0.0026,  0.0015, -0.0152],
        [-0.0068, -0.0066, -0.0049,  ...,  0.0021,  0.0121, -0.0018],
        [ 0.0132,  0.0139,  0.0135,  ...,  0.0050,  0.0105, -0.0064]])
1.weight tensor([[ 2.0540e-02,  1.3424e-02, -1.4864e-02, -1.3467e-02, -9.4628e-03,
          9.7067e-04, -1.5009e-02, -1.5539e-02, -7.9077e-03,  1.0551e-02,
         -9.0278e-03,  1.7459e-02, -1.8690e-03,  5.0541e-03, -1.9378e-03,
          1.6113e-02,  2.4179e-03, -3.9226e-03, -2.4951e-03, -1.0519e-02,
         -1.7979e-03, -3.0638e-03,  5.2334e-04,  1.1851e-02, -1.5307e-02,
         -2.3555e-02, -1.2260e-02, -8.4837e-03,  1.4811e-03, -8.6963e-03],
        [-7.8196e-03,  1.4540e-02,  2.6614e-03, -1.0711e-02,  1.5188e-02,
          3.4464e-04, -1.6673e-02,  1.4700e-02, -1.1430e-02, -3.0225e-03,
          1.7065e-03, -1.7021e-02,  1.0656e-02,  1.1011e-02,  4.2047e-03,
         -1.2465e-02, -6.2476e-03, -1.0956e-02,  3.1987e-03, -4.9426e-03,
         -1.2004e-03,  4.4904e-03, -1.0595e-02, -1.1566e-02,  1.2842e-02,
          2.6868e-03,  4.0029e-03,  1.3305e-02,  4.9962e-03,  3.2664e-03],
        [-3.7490e-03, -9.1373e-03, -7.6418e-03,  1.1468e-02, -7.3277e-03,
          1.0504e-02, -1.2510e-02,  2.0066e-02, -3.8254e-03, -1.3952e-03,
         -8.9643e-03, -7.9015e-03,  3.0163e-03,  4.8269e-04, -1.6445e-02,
         -4.2277e-03, -1.2489e-02,  7.0872e-03, -4.0340e-04, -6.6422e-03,
          2.0632e-03, -4.4652e-03, -2.5231e-03, -1.1253e-02,  1.6601e-02,
         -4.5058e-03,  4.3239e-03,  1.0291e-02,  2.0269e-02, -1.0369e-02],
        [-1.3724e-02, -8.1422e-04,  1.7627e-03, -5.2986e-03, -3.4685e-03,
         -2.0445e-03, -1.1192e-02,  1.0989e-02, -1.1012e-03, -1.7664e-02,
         -3.6856e-03,  9.6825e-03,  1.5568e-03,  8.0502e-03,  5.8132e-04,
         -9.6213e-03, -7.0347e-04,  5.8105e-03,  2.8272e-04,  1.1797e-02,
         -7.6907e-03,  2.2619e-02,  1.2308e-02,  1.8024e-02,  1.7888e-03,
         -1.7745e-03,  7.1993e-03, -3.5203e-04, -1.9954e-02, -1.8694e-02],
        [ 7.0645e-03,  8.1665e-04,  1.4387e-02,  1.2969e-02,  1.8918e-03,
         -5.8327e-03, -1.0239e-02,  8.8525e-03, -5.7654e-03,  1.7141e-03,
         -2.8454e-03,  1.3615e-02, -1.0575e-02, -3.0326e-03,  6.0094e-03,
         -1.0895e-02,  2.0619e-03,  1.5329e-03,  1.4229e-03, -3.2453e-03,
         -1.7779e-03,  9.6478e-03, -2.3171e-02,  1.4701e-02,  1.9202e-02,
          9.1253e-03, -6.1181e-03, -1.0587e-03,  2.0379e-04,  1.8809e-03],
        [ 1.2714e-02,  3.9214e-03,  5.0589e-04,  6.2322e-03,  1.3992e-02,
          1.4027e-03, -4.3150e-03,  1.0659e-03, -7.9712e-04, -1.9232e-02,
         -1.0893e-02,  2.5453e-02, -2.4113e-03, -8.1766e-03, -5.1438e-03,
         -1.5338e-02, -3.9863e-03,  5.6107e-03, -1.1374e-03,  5.2210e-03,
         -1.6869e-02, -5.0722e-03, -6.0378e-03,  3.9245e-05, -1.0497e-02,
          2.8167e-03,  1.5779e-02, -2.7614e-03, -2.5383e-03, -3.9343e-03],
        [ 6.0877e-03,  8.4987e-03,  1.3891e-02, -1.4258e-03, -8.7571e-03,
         -4.2930e-03, -9.5155e-04, -1.2515e-03, -2.7573e-03, -1.8363e-03,
          3.6556e-03,  1.8146e-03, -1.5488e-02,  4.5685e-04, -1.0984e-02,
         -3.8873e-03, -2.3433e-03, -8.7609e-03,  4.1699e-03,  9.4601e-03,
          3.1651e-04,  8.2427e-03,  2.3207e-03,  4.8457e-03, -1.5066e-02,
          6.2566e-03,  1.7090e-02,  9.2237e-03,  3.8094e-03, -3.4015e-03],
        [ 1.0200e-02,  1.2802e-02,  6.2122e-03, -1.1477e-02,  2.2691e-03,
         -4.6306e-04,  2.8861e-06,  9.0622e-03, -2.6586e-02, -7.3941e-04,
          6.5375e-07, -1.5064e-03,  2.4399e-02, -1.4029e-03, -4.1262e-03,
          1.2517e-02, -6.1972e-03, -8.5488e-03, -5.8397e-03,  1.9754e-03,
          4.5322e-03,  4.9351e-03, -8.4586e-03, -1.0668e-02, -4.0087e-04,
          7.3720e-03, -6.5349e-03, -1.7660e-03, -9.3602e-03,  1.5721e-03],
        [ 1.6188e-02, -2.0150e-03, -1.6816e-02, -1.6911e-02,  6.5990e-03,
         -1.6605e-02, -9.2304e-03, -8.1321e-03,  2.1806e-02, -7.1569e-03,
         -1.0543e-02, -6.5349e-03,  2.9192e-03, -1.5003e-02, -6.7353e-03,
          3.0836e-03, -1.0745e-04, -6.8194e-03,  1.0827e-02, -1.0363e-02,
         -1.2346e-02, -7.0676e-03,  1.2861e-02, -6.6271e-03, -5.7568e-03,
          3.4977e-03, -5.9773e-03,  3.5096e-03,  1.0743e-02,  8.9650e-03],
        [ 2.9357e-03,  7.0599e-03, -8.3739e-03, -3.4416e-03,  2.7688e-02,
         -1.0271e-02, -1.3291e-02, -9.5348e-03,  9.4741e-03,  5.1175e-03,
          9.3990e-03, -2.4245e-03, -7.4773e-03, -1.8540e-03, -1.7982e-02,
         -9.4906e-03,  1.2980e-02,  1.3263e-02,  7.2149e-04,  2.5788e-04,
         -2.8871e-04, -1.0207e-02,  4.8015e-03, -2.5008e-03, -1.0212e-02,
         -1.6521e-02, -1.0582e-02,  2.5365e-03,  2.0865e-03,  4.9698e-03],
        [ 4.8506e-03,  1.2074e-02, -1.1068e-02,  6.3328e-03, -7.2685e-03,
          8.6962e-03,  8.3850e-03,  3.9330e-04, -1.4435e-02,  2.6184e-03,
         -9.1552e-03, -9.1968e-03, -7.5484e-03,  1.7064e-02, -1.0926e-02,
          5.8421e-03,  3.2155e-03, -1.7801e-02, -7.1032e-03,  1.5225e-02,
          4.3300e-03, -1.1750e-02,  2.3955e-03,  5.8431e-03, -3.3767e-03,
          8.4297e-03,  5.9677e-03, -5.3130e-03, -6.2581e-03,  7.2684e-04],
        [ 1.0966e-03,  5.1526e-03,  1.2423e-02,  2.7335e-04, -2.2481e-03,
         -6.6638e-03,  1.0812e-02, -5.5681e-03, -3.6775e-03,  1.7190e-02,
         -8.9243e-03,  1.0826e-02,  2.1140e-02, -3.1026e-03, -7.8873e-03,
          1.1538e-02, -1.9044e-03,  1.0823e-02,  9.9094e-03,  2.5864e-03,
         -1.1288e-02,  7.8564e-03,  1.6430e-02, -5.0827e-03, -1.3557e-03,
          1.5250e-02, -7.7385e-03, -1.2449e-02,  1.0931e-02,  4.5603e-03],
        [ 1.0957e-02, -5.0123e-03, -9.2494e-03, -5.7311e-03,  1.0838e-02,
          3.2123e-03, -2.0542e-03, -7.1564e-03, -1.0734e-03,  1.9866e-02,
          1.6477e-02, -5.1313e-04,  3.8111e-03, -7.6825e-03, -6.2488e-03,
          5.7776e-03,  6.2000e-04, -1.7886e-02, -1.0113e-02,  1.1926e-03,
         -1.7908e-02, -9.1477e-03, -4.8008e-03, -4.7199e-03,  1.5774e-02,
          2.0051e-02, -1.7119e-02,  3.6467e-03, -2.3746e-03, -3.1700e-03],
        [-6.5938e-03, -8.9639e-03,  7.4318e-03, -7.2729e-03, -7.1252e-03,
         -2.4921e-03,  3.9307e-03,  3.5104e-03,  1.1045e-02, -5.6922e-03,
          4.3981e-03,  1.8079e-02, -4.0234e-03, -9.3731e-03,  2.8033e-03,
          1.2702e-02, -6.7108e-03,  1.1401e-03,  4.4380e-03,  3.1948e-02,
         -5.5021e-03,  1.4521e-03, -1.3206e-02, -2.4222e-03,  4.7926e-04,
         -1.1533e-02, -5.7481e-03,  1.2893e-02,  1.1545e-02,  5.2813e-03],
        [ 1.5414e-02, -6.9193e-03,  1.6852e-02, -1.0622e-02,  1.4744e-02,
          1.0154e-03,  1.3041e-02,  1.7373e-02, -3.9367e-03,  8.8151e-03,
         -1.1236e-02, -1.9134e-02, -2.9457e-04, -1.1307e-02,  5.1315e-04,
          1.5109e-02,  9.7158e-03, -2.8689e-03,  2.3455e-03,  1.0067e-02,
         -1.0586e-02, -8.3479e-03,  1.1931e-02,  8.1437e-04, -7.5972e-03,
         -1.2458e-03,  6.7771e-03,  1.0820e-02,  1.9163e-02,  1.5708e-02],
        [-9.7202e-03,  8.6719e-04,  2.0114e-02, -1.7551e-03,  1.0009e-02,
          2.3814e-06, -4.0599e-03, -1.6323e-03, -5.8922e-04, -1.9179e-02,
         -6.1471e-03,  3.2786e-03, -8.2910e-03,  1.5821e-02, -6.7673e-04,
          6.7152e-03, -9.1754e-03, -1.1328e-02, -2.6055e-03,  5.0403e-03,
          1.2042e-03, -3.8331e-03, -1.4951e-02, -1.3479e-02,  1.1225e-03,
          8.5821e-04, -6.5184e-03, -1.4773e-02, -5.9137e-03,  3.5191e-04],
        [ 1.8287e-02,  6.3982e-03,  9.5072e-05, -8.5228e-03,  3.8553e-03,
         -1.4717e-02, -2.2549e-02, -9.0621e-04, -1.2570e-02,  1.1718e-02,
         -6.6394e-03,  4.0225e-03,  1.1599e-02, -1.2683e-02, -1.6592e-02,
          6.4644e-03, -3.6911e-04, -1.5643e-02,  5.2307e-03, -6.2249e-03,
         -7.1715e-03, -2.3126e-03, -8.4576e-04,  3.8474e-03,  1.0998e-02,
         -4.6572e-03,  5.6035e-03,  3.4481e-02,  1.2175e-02,  5.7834e-03],
        [ 8.1615e-03,  1.1394e-02,  6.7657e-03,  6.2123e-03, -2.5495e-02,
          8.0885e-03, -1.4886e-02,  9.0453e-03,  7.6994e-04,  3.8311e-03,
          3.3331e-03, -1.1154e-03, -1.5299e-03, -1.0464e-04, -1.7200e-02,
          1.7913e-03,  6.0317e-03, -9.6406e-03, -3.2287e-03, -8.2056e-03,
         -3.6062e-03,  1.4878e-02,  2.8893e-04,  5.8158e-03,  2.6536e-03,
         -2.4103e-02, -1.4993e-02, -1.4791e-03, -5.8795e-06,  3.5734e-03],
        [-1.9504e-02, -5.2369e-03,  4.3967e-03, -8.9265e-03, -5.5365e-03,
         -3.4119e-03, -9.1115e-03,  1.0674e-02,  2.1772e-03,  1.2874e-02,
          4.4263e-03,  7.4545e-03, -2.3560e-03,  2.8753e-03, -1.1261e-02,
          2.3133e-02,  4.8261e-03,  8.4998e-03, -9.4744e-03,  2.2099e-02,
         -4.8882e-03, -5.5454e-03,  1.4507e-02, -6.0991e-03,  1.9609e-02,
          1.0477e-03, -1.3490e-02,  1.1814e-03, -8.1444e-03,  1.8736e-02],
        [ 2.4450e-02, -1.7491e-03, -5.5336e-03,  3.9693e-03, -9.6735e-04,
          3.7527e-03,  6.2744e-03, -2.6313e-03,  1.5444e-02, -4.2243e-03,
         -8.7079e-03,  2.0071e-02,  1.2334e-03, -8.3905e-03, -1.6589e-02,
          3.5387e-03, -3.9648e-03,  8.2130e-04,  1.4307e-02, -4.3372e-03,
          3.1793e-04, -7.4976e-03, -7.1900e-03,  1.2343e-03, -6.7502e-03,
         -1.1686e-02,  7.9501e-03, -2.0728e-02,  5.3000e-03,  2.5626e-02]])
2.linear.weight tensor([[-5.5996e-03, -1.7881e-02,  2.9029e-03, -3.8814e-04,  4.6640e-03,
          1.4729e-03, -8.7504e-04,  1.7051e-03, -9.9808e-03, -5.0457e-03,
          1.1433e-02,  2.6712e-03,  6.0539e-04, -1.9261e-02,  3.1117e-03,
          4.8112e-03, -7.1367e-03, -8.3227e-03, -8.6805e-03, -1.3699e-02],
        [ 9.2634e-03, -4.6123e-04,  8.3162e-05,  4.1494e-03,  5.1273e-03,
         -1.6681e-03,  1.4864e-02, -3.0277e-04,  1.7692e-03, -9.4125e-03,
          3.1612e-03, -1.3875e-02, -1.1496e-03,  1.5235e-02, -1.0851e-03,
          8.6537e-03,  9.0799e-03, -2.8304e-03,  7.5469e-03, -7.0789e-03],
        [-5.3225e-03, -8.5825e-03,  5.8378e-03,  1.1406e-03, -3.2881e-03,
          7.9792e-03,  1.5086e-02, -8.4687e-03, -1.3789e-02, -1.8288e-02,
          1.3957e-03,  2.2536e-03, -3.3606e-03,  3.2420e-03, -9.3056e-03,
          1.5156e-02,  6.4650e-03, -1.3013e-03,  1.0178e-02,  1.3913e-02],
        [-1.2846e-02,  4.8990e-04,  7.1575e-03,  1.3793e-03, -5.9738e-03,
          3.9912e-03,  2.5881e-03, -1.5927e-02, -9.1614e-03, -1.6528e-02,
          1.1311e-02, -4.5142e-03,  4.0138e-03, -2.2480e-02, -3.1659e-03,
         -1.8865e-02,  1.4666e-02,  1.0195e-02, -9.0439e-03,  3.8127e-03],
        [ 1.2961e-02,  7.7461e-03,  1.1123e-02,  8.3137e-03, -1.2393e-02,
         -3.1891e-03,  4.7997e-03,  1.3994e-02,  6.5888e-03,  1.6503e-02,
         -8.6732e-05,  1.0505e-02, -5.6654e-03, -5.1383e-03, -6.2080e-03,
          1.7060e-03,  1.0218e-02, -5.8861e-03,  1.1582e-02,  2.7832e-03],
        [ 7.2542e-03,  2.7732e-03,  6.6612e-03,  3.4965e-03, -4.0417e-03,
         -7.7164e-03, -1.9723e-02, -1.8491e-03,  9.4730e-03, -1.7118e-04,
          1.1873e-02, -7.9715e-03,  1.0753e-03,  1.0804e-02,  1.3700e-02,
          1.6791e-02, -1.4478e-02,  1.7035e-02,  9.3034e-04,  6.3716e-03],
        [-5.1573e-03, -1.3071e-03,  1.5302e-02, -6.1851e-03,  6.7677e-04,
          1.4603e-02,  7.0722e-03,  2.6819e-03,  1.5753e-02, -3.4153e-03,
         -2.0578e-03, -8.0549e-03, -1.2534e-02,  7.2134e-03, -1.6519e-02,
          7.4687e-04,  4.4496e-03, -2.0857e-03, -1.4627e-02,  7.0603e-03],
        [-3.0296e-04,  1.3288e-03, -7.3660e-03, -9.3445e-04,  7.9644e-03,
          4.3456e-03, -1.3778e-02,  4.5655e-03, -1.2824e-02, -1.1208e-02,
          1.8958e-02,  2.4357e-03, -2.1795e-02,  5.2504e-03,  8.8114e-04,
          2.6217e-03, -5.6610e-03,  3.5770e-04,  9.2574e-03,  7.0462e-04],
        [ 5.5887e-03, -2.3937e-02,  1.3323e-02,  1.2788e-02, -1.2492e-02,
         -8.2638e-03,  3.7073e-04, -6.0787e-03,  4.5608e-03, -2.0334e-02,
         -1.5575e-02, -7.6481e-03,  1.2293e-02, -4.0633e-03, -4.2219e-03,
         -8.4397e-03, -7.8156e-03,  1.4628e-02,  8.0971e-04,  4.0196e-04],
        [-1.3949e-02,  3.2856e-03,  3.0923e-03, -1.2260e-04, -1.1090e-02,
         -1.0921e-02,  1.2069e-02, -1.4299e-02,  6.6826e-03, -5.0484e-03,
         -1.6159e-02,  5.4746e-03,  4.6988e-03, -2.0970e-02, -1.2481e-03,
         -9.0344e-03, -1.4410e-02, -3.9734e-03,  4.2471e-03,  1.3298e-02],
        [-2.9720e-03,  1.4633e-02,  3.5870e-03,  2.0081e-03, -8.4944e-04,
         -1.0992e-02,  1.3077e-02, -1.0255e-03,  4.1050e-05, -9.8274e-03,
         -4.1936e-04,  1.1226e-02, -1.2842e-02,  1.6335e-02, -8.9160e-03,
          4.1783e-03,  3.9959e-03, -4.7716e-03, -3.1262e-03, -1.0138e-02],
        [ 2.0540e-02, -2.0969e-03, -3.8366e-04, -2.7185e-03, -3.2523e-03,
          7.9718e-03,  8.1142e-03,  3.6300e-03, -7.4875e-03,  1.3166e-02,
          1.3331e-03,  1.0332e-02,  1.1699e-02,  1.0125e-02,  1.0621e-02,
         -1.6513e-02, -3.5310e-03, -1.9825e-02, -8.0858e-03, -3.1577e-03],
        [ 9.9199e-03, -1.4978e-03, -2.9754e-03, -1.2934e-02, -1.4795e-02,
         -1.6737e-02, -6.4125e-03, -1.3493e-03, -6.9752e-03,  1.2560e-02,
         -4.3187e-03, -1.4194e-02,  3.6981e-04, -2.8193e-03,  8.1851e-03,
         -1.1203e-02,  1.9439e-02,  1.0795e-02, -8.2543e-03,  2.3599e-03],
        [-1.1758e-02,  9.5521e-03,  6.3462e-04,  3.5751e-03,  1.3527e-02,
          9.1637e-03, -4.8209e-03, -1.0592e-02,  2.0195e-03,  1.6119e-02,
          1.1654e-02, -1.0050e-02, -7.9837e-03,  1.9223e-02,  2.3013e-03,
          1.4753e-03,  6.0071e-03,  4.8298e-04,  6.6951e-03, -1.2826e-02],
        [-2.0467e-02, -4.2140e-03,  6.8931e-03, -5.0351e-03, -1.0566e-02,
         -6.9593e-03,  4.7499e-03,  8.2981e-03,  2.5675e-02,  1.5285e-02,
         -9.5692e-03, -4.5086e-03, -5.6572e-03, -1.1610e-03, -1.7814e-03,
          1.2086e-02,  9.0804e-03, -9.1113e-03, -1.9917e-02, -1.6530e-02],
        [-8.3130e-03,  1.3344e-02, -1.0923e-02, -2.2851e-02,  6.8411e-03,
          2.1356e-03,  8.5286e-03, -5.2713e-03, -1.2193e-02, -1.1948e-03,
          3.3531e-04, -4.8021e-03,  5.7277e-03,  1.1665e-02, -1.0539e-04,
          7.7064e-03,  1.4751e-02, -4.4743e-03,  1.4951e-02,  1.2392e-02],
        [-1.3800e-02,  1.3546e-02, -6.8238e-03,  1.4270e-02, -1.2051e-02,
          5.9353e-03,  2.4619e-03,  1.5746e-03,  7.2940e-03, -9.0082e-04,
         -5.7518e-03, -2.0381e-02, -7.1935e-03,  2.1555e-03, -5.1282e-03,
         -1.9255e-02, -7.8506e-03, -1.4166e-02,  3.6891e-03, -6.8586e-03],
        [-1.3660e-02,  2.2254e-02,  1.7642e-02, -3.0450e-03,  1.3232e-02,
         -1.0315e-02,  1.0142e-02,  3.5027e-04, -5.0814e-04, -2.6084e-03,
          2.5116e-02,  5.9374e-03,  5.4630e-03,  1.1566e-02, -8.3270e-03,
         -7.3726e-03,  1.1869e-02, -1.8963e-02,  5.7513e-03, -8.7186e-03],
        [-2.0966e-03, -1.7567e-03,  1.4226e-03, -6.6999e-03,  4.1837e-03,
         -7.0836e-04,  5.6019e-03, -6.0745e-03,  7.3432e-03,  6.0721e-03,
          5.1656e-03, -1.8158e-03, -8.8598e-03,  2.4378e-02,  4.7904e-03,
          7.9625e-03,  5.9612e-03, -2.0552e-03, -4.8240e-03, -1.4750e-02],
        [ 7.3819e-03, -6.9134e-03,  8.1345e-03, -1.4451e-02,  7.6821e-03,
         -1.5124e-03,  3.8898e-03,  1.2265e-03, -9.0760e-03, -1.5998e-02,
          2.7463e-04, -6.5673e-03, -7.7155e-04, -2.0952e-02, -2.3035e-03,
         -1.2220e-03,  1.2617e-02,  5.1720e-03, -1.7524e-02,  5.6977e-03]])
for name, param in net.named_parameters():
    if 'bias' in name:
        init.constant_(param, val=0)
        print(name, param.data)
0.net.0.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0.])
1.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
2.linear.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

自定义初始化方法

有时候我们需要的初始化方法并没有在init模块中提供。这时,可以实现一个初始化方法,从而能够像使用其他初始化方法那样使用它。在这之前我们先来看看PyTorch是怎么实现这些初始化方法的,例如torch.nn.init.normal_:

def normal_(tensor, mean=0, std=1):
    with torch.no_grad():
        return tensor.normal_(mean, std)

可以看到这就是一个inplace改变Tensor值的函数,而且这个过程是不记录梯度的。 类似的我们来实现一个自定义的初始化方法。在下面的例子里,我们令权重有一半概率初始化为0,有另一半概率初始化为[−10,−5][−10,−5]和[5,10][5,10]两个区间里均匀分布的随机数。

def init_weight_(tensor):
    with torch.no_grad():
        tensor.uniform_(-10, 10)
        tensor *= (tensor.abs() >= 5).float()

for name, param in net.named_parameters():
    if 'weight' in name:
        init_weight_(param)
        print(name, param.data)
0.net.0.weight tensor([[-0.0000, -6.2573, -9.6670,  ..., -6.2055, -0.0000,  0.0000],
        [ 0.0000, -6.1450, -0.0000,  ...,  9.4352,  8.5185,  0.0000],
        [ 0.0000,  0.0000, -0.0000,  ...,  0.0000, -0.0000, -5.4052],
        ...,
        [-0.0000, -9.5803, -5.8481,  ..., -0.0000, -0.0000, -0.0000],
        [-9.8709,  0.0000, -0.0000,  ...,  0.0000,  7.7648, -7.7570],
        [-8.7503, -9.4466,  8.8122,  ...,  5.5594, -0.0000,  9.3528]])
1.weight tensor([[-5.3031, -7.1580, -9.9964, -6.4769,  0.0000, -5.8270,  5.7411, -5.8713,
         -0.0000, -0.0000, -0.0000,  5.3456,  0.0000,  0.0000, -0.0000,  6.5026,
          0.0000, -0.0000,  0.0000,  0.0000, -6.3099,  0.0000, -0.0000,  0.0000,
         -7.5503, -0.0000,  9.9910, -0.0000, -0.0000,  6.3966],
        [ 5.5881,  0.0000, -0.0000, -8.1178,  9.6688, -0.0000,  0.0000,  8.6963,
          0.0000,  8.0869,  0.0000,  6.0899,  8.3442, -6.8183, -8.3160, -0.0000,
         -0.0000, -7.4499, -6.6586, -0.0000,  0.0000, -6.7071, -0.0000,  7.4049,
          0.0000, -9.0628, -0.0000,  7.3452,  0.0000, -8.3752],
        [-6.3758,  0.0000,  0.0000,  9.1926,  7.5408,  0.0000, -9.9305,  0.0000,
         -0.0000, -9.0137,  0.0000,  8.7384, -0.0000, -6.3803, -0.0000,  0.0000,
         -8.8573,  8.8939,  0.0000,  6.2515,  5.3137,  7.6621, -0.0000, -8.0161,
         -8.4190,  6.1407, -0.0000, -6.5521,  5.4117, -0.0000],
        [ 0.0000, -9.7822,  0.0000, -0.0000,  0.0000, -9.9513,  0.0000, -0.0000,
          9.5243,  6.7586, -6.7599, -7.2046,  5.7396,  0.0000, -7.4426,  6.5923,
          5.4310,  0.0000,  9.5645, -0.0000,  9.5840, -7.3611, -0.0000,  5.2647,
         -5.9115, -0.0000, -0.0000, -0.0000, -5.2489, -8.3718],
        [ 5.0601,  8.8166,  0.0000,  6.9072, -8.9312,  0.0000,  7.9705, -9.0079,
          6.0310, -9.9306,  0.0000,  5.2559,  0.0000, -0.0000,  0.0000,  6.6453,
         -8.2131, -0.0000, -9.7696, -0.0000, -6.1296, -6.1960,  0.0000, -7.1193,
          6.0872,  8.9191, -7.5285,  8.8542,  0.0000, -0.0000],
        [ 5.7640, -8.2749,  0.0000, -0.0000, -0.0000,  8.4516, -0.0000,  8.1646,
         -7.5594, -0.0000, -7.6956,  0.0000,  0.0000,  8.4151,  0.0000, -0.0000,
         -9.5694, -6.6565,  5.6100, -7.5257, -0.0000,  7.6025,  6.7463, -9.1704,
         -0.0000,  0.0000,  8.9362,  7.7404, -9.6161,  0.0000],
        [-5.9001,  0.0000, -0.0000, -7.5054,  0.0000, -7.8240,  7.8767,  7.6504,
          9.0638,  7.9266,  7.0275, -0.0000, -6.6694,  5.2457,  9.6615,  0.0000,
          5.6970, -7.9317, -0.0000,  0.0000,  6.3138, -0.0000, -0.0000,  6.8197,
         -0.0000, -5.5665,  5.6393,  6.9409,  8.4405, -6.9706],
        [ 0.0000, -6.5044, -5.4181,  0.0000, -8.2382,  6.3029, -0.0000, -0.0000,
          0.0000, -0.0000, -8.7580,  6.9954,  8.0429, -9.6811,  0.0000,  7.1997,
          9.6289, -6.6778,  8.3476,  0.0000,  0.0000, -5.1837, -7.5868,  8.8954,
          0.0000,  9.4241, -5.2790,  9.6536, -6.5575, -7.8887],
        [ 8.6683, -9.6715,  0.0000,  5.7765,  6.6276,  8.2964, -0.0000, -6.1496,
          8.9688, -0.0000, -0.0000,  8.1510, -0.0000,  6.2618,  0.0000, -7.1139,
         -9.1306,  8.7948, -8.2000,  0.0000, -7.7904, -0.0000,  8.5050, -7.8431,
         -0.0000, -0.0000,  7.9535, -6.2632, -6.1001, -6.1616],
        [-8.8006, -8.2721, -6.8235,  0.0000, -9.7454,  0.0000, -0.0000,  0.0000,
          6.3692,  9.8793,  0.0000,  0.0000, -7.2012,  0.0000,  0.0000,  0.0000,
         -0.0000,  0.0000,  8.0920, -0.0000, -9.2214, -0.0000, -8.5031, -8.3191,
         -0.0000,  7.7950, -0.0000,  0.0000,  0.0000, -7.7035],
        [ 7.3713,  6.2752,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  9.8827,
         -9.9048, -0.0000,  0.0000,  8.6781,  7.9049, -0.0000, -0.0000, -0.0000,
          0.0000, -0.0000,  0.0000,  5.6662, -5.7249,  7.5572,  7.9236, -8.0282,
          0.0000, -0.0000, -8.1768,  0.0000, -0.0000, -0.0000],
        [-9.1349,  8.1316,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000, -0.0000,
          6.6178, -0.0000, -0.0000,  0.0000,  7.0107,  8.4189, -0.0000,  5.0933,
          9.8502, -0.0000, -9.1036,  0.0000, -5.6974,  0.0000, -9.8530, -8.0555,
          0.0000, -5.2797, -7.7184, -0.0000, -5.4591,  8.9083],
        [-0.0000,  7.5084,  0.0000, -7.2532,  0.0000, -0.0000,  0.0000,  8.1774,
         -9.4518, -5.2770, -0.0000, -0.0000,  5.0882,  7.4918, -8.7120, -0.0000,
         -7.0417,  0.0000, -0.0000,  5.1748,  0.0000, -0.0000, -0.0000, -9.3591,
          0.0000, -8.2938, -8.9266, -0.0000,  0.0000,  0.0000],
        [ 9.5399,  0.0000,  7.4996, -0.0000,  9.8394,  0.0000,  0.0000, -0.0000,
          0.0000, -0.0000,  0.0000,  0.0000, -5.3790, -9.6425,  9.8764,  0.0000,
          7.0963, -8.9114, -0.0000,  9.0510,  5.4095, -5.0019,  0.0000, -7.0848,
         -0.0000,  0.0000,  0.0000, -0.0000,  0.0000, -9.8849],
        [ 8.1293,  0.0000, -6.7563,  0.0000,  7.6843,  8.4486,  6.6997,  7.9062,
         -5.6233,  0.0000,  6.2966,  7.4003,  0.0000, -0.0000,  8.2519,  0.0000,
          8.9715, -8.4958, -7.0926, -0.0000, -0.0000, -5.9202,  6.4348,  8.4848,
          0.0000, -7.7669, -7.1864,  0.0000,  9.1471, -0.0000],
        [ 0.0000,  8.9136, -6.4488,  0.0000, -0.0000,  0.0000, -0.0000,  0.0000,
          9.3667,  7.9826,  0.0000,  0.0000, -6.9993,  0.0000,  9.6924,  0.0000,
          8.0840,  9.4185,  0.0000,  6.2945,  7.2634, -0.0000, -6.5016, -0.0000,
          0.0000, -0.0000,  6.4698, -0.0000,  5.2106, -0.0000],
        [-0.0000,  6.1383, -9.1950, -5.4267, -0.0000, -8.8607,  9.7967,  0.0000,
          0.0000,  0.0000,  5.7624, -9.8647,  7.0602, -9.4853, -0.0000,  9.4742,
         -5.6148, -6.2168,  0.0000,  9.9788, -9.1114,  0.0000,  0.0000, -0.0000,
          6.7838,  0.0000,  0.0000, -0.0000,  6.0502, -5.0156],
        [ 8.1110, -0.0000,  6.7124, -6.7490,  5.9337, -8.4773, -0.0000, -7.2125,
          0.0000, -0.0000,  5.2543,  0.0000,  0.0000,  9.5925,  0.0000, -0.0000,
         -0.0000, -0.0000, -8.4826,  0.0000, -0.0000, -6.4074,  0.0000, -0.0000,
          7.2494, -8.8358,  7.8122,  8.8839,  7.9619,  7.8244],
        [ 0.0000, -6.1305,  0.0000,  6.7742,  5.9301, -0.0000,  0.0000, -9.8249,
          7.4770,  0.0000,  6.4683, -0.0000,  0.0000, -0.0000, -7.5854, -0.0000,
         -9.8200, -8.4659, -0.0000, -0.0000,  0.0000,  8.8622,  8.4771,  5.8597,
          0.0000,  0.0000,  6.5139,  0.0000, -6.5754, -0.0000],
        [-0.0000,  8.0858, -0.0000,  0.0000, -0.0000,  9.6130, -0.0000, -5.2527,
         -9.6432, -5.0170,  5.0056,  7.9539,  0.0000, -0.0000, -8.1660, -0.0000,
          0.0000, -0.0000,  0.0000, -7.4998,  0.0000, -0.0000, -0.0000,  9.5478,
          8.9482, -0.0000,  0.0000, -6.8827,  8.9620, -9.7837]])
2.linear.weight tensor([[ 0.0000, -6.9812,  6.5521, -0.0000, -9.5926, -0.0000, -0.0000, -5.0386,
          6.5204,  0.0000,  8.5534,  9.0793,  0.0000, -0.0000, -0.0000, -5.9300,
         -5.9106, -0.0000,  8.0045,  7.3943],
        [-5.1885, -0.0000, -0.0000,  5.7570,  0.0000, -0.0000, -8.7268,  9.2548,
          0.0000, -7.7151, -6.8379, -8.5315, -0.0000,  0.0000,  6.4908,  0.0000,
         -0.0000, -0.0000, -5.9161, -5.8564],
        [ 0.0000, -9.4816, -9.4379, -0.0000, -9.3362, -0.0000,  0.0000, -8.2363,
         -0.0000,  8.0405,  8.9864,  0.0000, -9.2516, -6.2904,  5.7813,  5.3719,
         -0.0000, -8.8506,  0.0000, -9.7931],
        [ 9.4502, -9.2087,  7.3578, -0.0000, -0.0000, -5.6518,  6.7095,  0.0000,
         -0.0000, -0.0000,  5.9396,  7.8704,  0.0000, -0.0000, -7.2353,  0.0000,
         -0.0000,  7.0156,  0.0000, -5.0929],
        [ 0.0000, -9.9551, -0.0000,  6.6445,  7.0596, -0.0000, -0.0000, -7.0890,
         -0.0000, -0.0000,  0.0000,  8.3706,  0.0000, -7.1886,  0.0000, -8.4613,
         -0.0000, -7.1553,  0.0000, -5.9845],
        [-0.0000,  7.2888, -0.0000, -0.0000,  0.0000, -9.2040,  5.5700, -7.6277,
          0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -8.3413,  5.5270,  0.0000,
          7.4749, -0.0000, -7.7825, -6.2418],
        [ 0.0000, -0.0000, -9.5633,  0.0000,  7.8421,  0.0000, -0.0000, -0.0000,
          0.0000,  7.5993,  7.9292,  7.3422,  9.5524, -5.4055, -0.0000,  7.3166,
         -0.0000,  0.0000,  9.8042,  6.4248],
        [ 0.0000, -0.0000,  9.5345,  9.1690, -0.0000, -0.0000,  5.9536, -6.3916,
          0.0000,  8.0737, -0.0000,  9.2005, -0.0000,  0.0000, -0.0000,  0.0000,
          0.0000,  0.0000, -0.0000,  5.8255],
        [-0.0000, -0.0000,  0.0000,  0.0000, -8.6615,  0.0000, -6.7372, -8.7178,
         -7.8605,  8.0300, -9.7292, -0.0000, -6.1143,  9.4508,  0.0000, -0.0000,
          0.0000,  8.0177, -7.3347,  0.0000],
        [ 6.3698,  0.0000, -0.0000, -0.0000, -7.8566, -0.0000,  0.0000, -0.0000,
          9.4590, -8.1037,  0.0000, -6.8476, -5.9440, -0.0000,  0.0000, -7.6990,
          0.0000, -8.9955, -0.0000, -0.0000],
        [-0.0000,  7.5792,  0.0000, -0.0000,  0.0000, -9.8703,  9.7579, -8.2181,
         -0.0000, -6.6664,  9.2518, -0.0000,  0.0000,  6.6096, -9.9753,  0.0000,
          7.4476, -8.6817, -7.4076, -6.4940],
        [-8.1521,  9.0149,  6.9163, -0.0000,  0.0000, -0.0000,  7.4414, -0.0000,
         -5.4136,  0.0000,  0.0000,  0.0000,  5.6016,  8.2639,  0.0000, -0.0000,
          8.3907,  5.6836, -7.3025,  6.4559],
        [-0.0000, -0.0000, -6.3013,  0.0000,  0.0000,  0.0000,  7.3791, -0.0000,
          9.9710,  0.0000, -6.0639,  0.0000, -5.8077,  5.1654,  0.0000,  7.9023,
         -0.0000, -0.0000, -0.0000, -8.7207],
        [ 5.5755,  0.0000, -6.9844,  7.6312, -0.0000,  0.0000,  8.0948,  0.0000,
          0.0000,  0.0000, -0.0000,  0.0000, -0.0000, -0.0000, -0.0000, -7.1881,
         -9.8673, -0.0000,  6.4959,  5.2270],
        [ 0.0000,  0.0000, -9.7369, -8.0855,  0.0000,  7.7631,  0.0000, -6.6331,
         -0.0000,  0.0000, -0.0000, -8.9024, -0.0000,  7.6198, -0.0000,  5.1892,
          0.0000,  8.1393,  0.0000,  9.0153],
        [-0.0000, -7.5369,  7.0477, -0.0000,  0.0000,  0.0000,  0.0000, -0.0000,
         -6.6719, -9.6702, -0.0000,  6.2202,  5.0695, -0.0000,  0.0000, -5.9180,
         -0.0000,  0.0000,  9.3806,  5.2014],
        [-0.0000,  5.1159, -7.1775, -0.0000, -5.8047,  7.2177, -7.7319, -0.0000,
         -0.0000, -9.0223, -0.0000,  0.0000,  0.0000, -0.0000,  8.1462, -8.1729,
          6.9001, -0.0000, -0.0000,  8.9680],
        [-0.0000,  0.0000, -0.0000,  0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
          8.0092, -8.0028,  0.0000, -0.0000,  7.1211, -0.0000,  0.0000,  0.0000,
          9.3772, -0.0000,  0.0000, -0.0000],
        [ 0.0000,  5.9549,  5.3347, -0.0000,  0.0000,  6.8372, -0.0000,  7.0890,
          6.8374,  6.8444,  5.3985,  5.1464, -0.0000, -0.0000,  8.7471,  6.0348,
         -0.0000,  0.0000,  6.7883,  8.7566],
        [ 7.5425,  0.0000, -0.0000, -6.2917,  9.6904, -6.8175, -9.9985,  0.0000,
          5.8819,  8.6874,  0.0000,  6.9018,  0.0000, -7.0253, -0.0000,  0.0000,
          7.5909,  0.0000,  0.0000, -0.0000]])

参考2.3.2节,我们还可以通过改变这些参数的data来改写模型参数值同时不会影响梯度:

for name, param in net.named_parameters():
    if 'bias' in name:
        param.data += 1
        print(name, param.data)
        
0.net.0.bias tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
1.bias tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.])
2.linear.bias tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.])

共享模型参数

在有些情况下,我们希望在多个层之间共享模型参数。4.1.3节提到了如何共享模型参数: Module类的forward函数里多次调用同一个层。此外,如果我们传入Sequential的模块是同一个Module实例的话参数也是共享的,下面来看一个例子:

linear = nn.Linear(1, 1, bias=False)
net = nn.Sequential(linear, linear) 
print(net)
for name, param in net.named_parameters():
    init.constant_(param, val=3)
    print(name, param.data)
Sequential(
  (0): Linear(in_features=1, out_features=1, bias=False)
  (1): Linear(in_features=1, out_features=1, bias=False)
)
0.weight tensor([[3.]])

在内存中,这两个线性层其实一个对象:

print(id(net[0]) == id(net[1]))
print(id(net[0].weight) == id(net[1].weight))
True
True

因为模型参数里包含了梯度,所以在反向传播计算时,这些共享的参数的梯度是累加的:

x = torch.ones(1, 1)
y = net(x).sum()
print(y)
y.backward()
print(net[0].weight.grad) # 单次梯度是3,两次所以就是6
tensor(9., grad_fn=<SumBackward0>)
tensor([[6.]])

自定义层

不含模型参数的自定义层

我们先介绍如何定义一个不含模型参数的自定义层。事实上,这和4.1节(模型构造)中介绍的使用Module类构造模型类似。下面的CenteredLayer类通过继承Module类自定义了一个将输入减掉均值后输出的层,并将层的计算定义在了forward函数里。这个层里不含模型参数。

import torch 
from torch import nn
class CenteredLayer(nn.Module):
    def __init__(self,**kwargs):
        super(CenteredLayer,self).__init__(**kwargs)
    def forward(self,x):
        return x-x.mean()
layer = CenteredLayer()
layer(torch.tensor([1,2,3,4,5],dtype = torch.float))
tensor([-2., -1.,  0.,  1.,  2.])
net = nn.Sequential(nn.Linear(8,128,CenteredLayer()))
y = net(torch.rand(4,8))
y.mean().item()
0.005987975746393204

含模型参数的自定义层

在4.2节(模型参数的访问、初始化和共享)中介绍了Parameter类其实是Tensor的子类,如果一个Tensor是Parameter,那么它会自动被添加到模型的参数列表里。所以在自定义含模型参数的层时,我们应该将参数定义成Parameter,除了像4.2.1节那样直接定义成Parameter类外,还可以使用ParameterList和ParameterDict分别定义参数的列表和字典。

ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用append和extend在列表后面新增参数。

class MyDense(nn.Module):
    def __init__(self):
        super(MyDense,self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4,4)) for i in range(3)])
        self.params.append(nn.Parameter(torch.randn(4,1)))
    def forward(self,x):
        for i in range(len(self.params)):
            x = torch.mm(x,self.params[i])
        return x
net = MyDense()
print(net)
MyDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)

而ParameterDict接收一个Parameter实例的字典作为输入然后得到一个参数字典,然后可以按照字典的规则使用了。例如使用update()新增参数,使用keys()返回所有键值,使用items()返回所有键值对等等,可参考官方文档。

class MyDictDense(nn.Module):
    def __init__(self):
        super(MyDictDense, self).__init__()
        self.params = nn.ParameterDict({
                'linear1': nn.Parameter(torch.randn(4, 4)),
                'linear2': nn.Parameter(torch.randn(4, 1))
        })
        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))}) # 新增

    def forward(self, x, choice='linear1'):
        return torch.mm(x, self.params[choice])

net = MyDictDense()
print(net)

MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)

这样就可以根据传入的键值来进行不同的前向传播:

x = torch.ones(1, 4)
print(net(x, 'linear1'))
print(net(x, 'linear2'))
print(net(x, 'linear3'))
tensor([[ 5.4923,  0.4334, -1.8005, -0.0718]], grad_fn=<MmBackward>)
tensor([[2.9095]], grad_fn=<MmBackward>)
tensor([[ 1.0343, -3.6477]], grad_fn=<MmBackward>)

我们也可以使用自定义层构造模型。它和PyTorch的其他层在使用上很类似。

net = nn.Sequential(
    MyDictDense(),
    MyDense(),
)
print(net)
print(net(x))
Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
    )
  )
  (1): MyDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 4x4]
        (1): Parameter containing: [torch.FloatTensor of size 4x4]
        (2): Parameter containing: [torch.FloatTensor of size 4x4]
        (3): Parameter containing: [torch.FloatTensor of size 4x1]
    )
  )
)
tensor([[-8.2394]], grad_fn=<MmBackward>)

读取与存储

读写Tensor

我们可以直接使用save函数和load函数分别存储和读取Tensor。save使用Python的pickle实用程序将对象进行序列化,然后将序列化的对象保存到disk,使用save可以保存各种对象,包括模型、张量和字典等。而load使用pickle unpickle工具将pickle的对象文件反序列化为内存。

下面的例子创建了Tensor变量x,并将其存在文件名同为x.pt的文件里。

import torch
from torch import nn
x= torch.ones(3)
print(x)
torch.save(x,'x.pt')
tensor([1., 1., 1.])
x2=torch.load('x.pt')
x2
tensor([1., 1., 1.])

还可以存储一个Tensor列表并读回内存。

y = torch.zeros(4)
torch.save([x,y],'xy.pt')
xy_list = torch.load('xy.pt')
xy_list
[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]

存储并读取一个从字符串映射到Tensor的字典。

torch.save({'x': x, 'y': y}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy
{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}

读写模型

在PyTorch中,Module的可学习参数(即权重和偏差),模块模型包含在参数中(通过model.parameters()访问)。state_dict是一个从参数名称隐射到参数Tesnor的字典对象。

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

net = MLP()
net.state_dict()
OrderedDict([('hidden.weight', tensor([[ 0.3508, -0.0527,  0.5097],
                      [ 0.1291,  0.4404,  0.2258]])),
             ('hidden.bias', tensor([0.0332, 0.4611])),
             ('output.weight', tensor([[0.1315, 0.0635]])),
             ('output.bias', tensor([-0.4548]))])
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()
{'param_groups': [{'dampening': 0,
   'lr': 0.001,
   'momentum': 0.9,
   'nesterov': False,
   'params': [2501567222120, 2501567222480, 2501566216088, 2501565231320],
   'weight_decay': 0}],
 'state': {}}

仅保存和加载模型参数(state_dict)

torch.save(model.state_dict(),PATH)  # 推荐的文件后缀名是pt或pth

#加载
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
---------------------------------------------------------------------------

NameError                                 Traceback (most recent call last)

<ipython-input-55-60ddd59ce579> in <module>()
----> 1 torch.save(model.state_dict(),PATH)  # 推荐的文件后缀名是pt或pth
      2 
      3 #加载
      4 model = TheModelClass(*args, **kwargs)
      5 model.load_state_dict(torch.load(PATH))


NameError: name 'model' is not defined

保存和加载整个模型

torch.save(model, PATH)

#加载

model = torch.load(PATH)
X = torch.randn(2, 3)
Y = net(X)

PATH = "./net.pt"
torch.save(net.state_dict(), PATH)

net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y
tensor([[True],
        [True]])

模型的GPU计算

用torch.cuda.is_available()查看GPU是否可用

torch.cuda.is_available() 
False

查看GPU数量:

torch.cuda.device_count()
0

查看当前GPU索引号,索引号从0开始:

torch.cuda.current_device() # 输出 0
---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

<ipython-input-59-55c583da45aa> in <module>()
----> 1 torch.cuda.get_device_name(0)


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in get_device_name(device)
    291             if :attr:`device` is ``None`` (default).
    292     """
--> 293     return get_device_properties(device).name
    294 
    295 


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in get_device_properties(device)
    313 def get_device_properties(device):
    314     if not _initialized:
--> 315         init()  # will define _get_device_properties and _CudaDeviceProperties
    316     device = _get_device_index(device, optional=True)
    317     if device < 0 or device >= device_count():


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in init()
    159     Does nothing if the CUDA state is already initialized.
    160     """
--> 161     _lazy_init()
    162 
    163 


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in _lazy_init()
    176         raise RuntimeError(
    177             "Cannot re-initialize CUDA in forked subprocess. " + msg)
--> 178     _check_driver()
    179     torch._C._cuda_init()
    180     _cudart = _load_cudart()


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in _check_driver()
     90 def _check_driver():
     91     if not hasattr(torch._C, '_cuda_isDriverSufficient'):
---> 92         raise AssertionError("Torch not compiled with CUDA enabled")
     93     if not torch._C._cuda_isDriverSufficient():
     94         if torch._C._cuda_getDriverVersion() == 0:


AssertionError: Torch not compiled with CUDA enabled

根据索引号查看GPU名字:

torch.cuda.get_device_name(0) # 输出 'GeForce GTX 1050'
---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

<ipython-input-60-45381b0646d2> in <module>()
----> 1 torch.cuda.get_device_name(0) # 输出 'GeForce GTX 1050'


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in get_device_name(device)
    291             if :attr:`device` is ``None`` (default).
    292     """
--> 293     return get_device_properties(device).name
    294 
    295 


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in get_device_properties(device)
    313 def get_device_properties(device):
    314     if not _initialized:
--> 315         init()  # will define _get_device_properties and _CudaDeviceProperties
    316     device = _get_device_index(device, optional=True)
    317     if device < 0 or device >= device_count():


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in init()
    159     Does nothing if the CUDA state is already initialized.
    160     """
--> 161     _lazy_init()
    162 
    163 


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in _lazy_init()
    176         raise RuntimeError(
    177             "Cannot re-initialize CUDA in forked subprocess. " + msg)
--> 178     _check_driver()
    179     torch._C._cuda_init()
    180     _cudart = _load_cudart()


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in _check_driver()
     90 def _check_driver():
     91     if not hasattr(torch._C, '_cuda_isDriverSufficient'):
---> 92         raise AssertionError("Torch not compiled with CUDA enabled")
     93     if not torch._C._cuda_isDriverSufficient():
     94         if torch._C._cuda_getDriverVersion() == 0:


AssertionError: Torch not compiled with CUDA enabled

同Tensor类似,PyTorch模型也可以通过.cuda转换到GPU上。我们可以通过检查模型的参数的device属性来查看存放模型的设备。

net = nn.Linear(3, 1)
list(net.parameters())[0].device
device(type='cpu')

可见模型在CPU上,将其转换到GPU上:

net.cuda()
list(net.parameters())[0].device
---------------------------------------------------------------------------

AssertionError                            Traceback (most recent call last)

<ipython-input-62-a9404ab4efc6> in <module>()
----> 1 net.cuda()
      2 list(net.parameters())[0].device


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\nn\modules\module.py in cuda(self, device)
    309             Module: self
    310         """
--> 311         return self._apply(lambda t: t.cuda(device))
    312 
    313     def cpu(self):


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\nn\modules\module.py in _apply(self, fn)
    228                 # `with torch.no_grad():`
    229                 with torch.no_grad():
--> 230                     param_applied = fn(param)
    231                 should_use_set_data = compute_should_use_set_data(param, param_applied)
    232                 if should_use_set_data:


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\nn\modules\module.py in <lambda>(t)
    309             Module: self
    310         """
--> 311         return self._apply(lambda t: t.cuda(device))
    312 
    313     def cpu(self):


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in _lazy_init()
    176         raise RuntimeError(
    177             "Cannot re-initialize CUDA in forked subprocess. " + msg)
--> 178     _check_driver()
    179     torch._C._cuda_init()
    180     _cudart = _load_cudart()


D:\OfficeSoftware\pycharm\anaconda\lib\site-packages\torch\cuda\__init__.py in _check_driver()
     90 def _check_driver():
     91     if not hasattr(torch._C, '_cuda_isDriverSufficient'):
---> 92         raise AssertionError("Torch not compiled with CUDA enabled")
     93     if not torch._C._cuda_isDriverSufficient():
     94         if torch._C._cuda_getDriverVersion() == 0:


AssertionError: Torch not compiled with CUDA enabled
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值