Model Construction and Basic Deep Learning Operations (2020/04/01)

This post covers how to build deep learning models in PyTorch, including subclassing the Module class, using the Sequential subclass, accessing and initializing model parameters, sharing parameters, and saving and loading models. It focuses on the torch.nn.init.normal_ initialization method and on saving and loading the state_dict.

Constructing a Model by Subclassing the Module Class

Define an MLP class that overrides the Module class's __init__ and forward functions, which construct the model and perform the forward computation, respectively. There is no need to define a backward function: the system generates one automatically through autograd. Calling net(X) invokes Module's __call__ method, which in turn dispatches to forward.

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden=nn.Linear(784,256)
        self.act=nn.ReLU()
        self.output=nn.Linear(256,10)
    def forward(self,X):
        a=self.act(self.hidden(X))
        return self.output(a)
# Instantiate the network
X=torch.rand(2,784)
net=MLP()
print(net)
net(X)
MLP(
  (hidden): Linear(in_features=784, out_features=256, bias=True)
  (act): ReLU()
  (output): Linear(in_features=256, out_features=10, bias=True)
)

tensor([[-0.1568, -0.1277,  0.1709,  0.1437,  0.0183,  0.2180, -0.1757, -0.2499,
          0.0646,  0.0176],
        [-0.0939, -0.1590,  0.0674,  0.2006, -0.0693,  0.1233, -0.1146, -0.2737,
         -0.0184, -0.0989]], grad_fn=<AddmmBackward>)

Module Subclasses

(1) Sequential makes it convenient to chain layers for a simple feed-forward computation. Its implementation works roughly as follows:

from collections import OrderedDict

class MySequential(nn.Module):
    def __init__(self,*args):
        super().__init__()
        if len(args)==1 and isinstance(args[0],OrderedDict):
            # a single OrderedDict argument: add each module under its key
            for key,module in args[0].items():
                self.add_module(key,module)
        else:
            # otherwise add the modules in order, named by their index
            for idx,module in enumerate(args):
                self.add_module(str(idx),module)
    def forward(self,input):
        # self._modules is an OrderedDict, so iteration follows insertion order
        for module in self._modules.values():
            input=module(input)
        return input
net=MySequential(nn.Linear(784,256),nn.ReLU(),nn.Linear(256,10))
print(net)
net(X)
MySequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

tensor([[-0.0798,  0.1060, -0.0161, -0.0773, -0.0844,  0.2117,  0.2107, -0.0883,
          0.1431, -0.0817],
        [ 0.0164,  0.0234, -0.0436, -0.0638, -0.0055,  0.2168,  0.0212, -0.2043,
          0.0989, -0.0613]], grad_fn=<AddmmBackward>)

Accessing Model Parameters

print(type(net.named_parameters()))   # check the type
<class 'generator'>
for name,param in net.named_parameters():
    print(name,param.size())
0.weight torch.Size([256, 784])
0.bias torch.Size([256])
2.weight torch.Size([10, 256])
2.bias torch.Size([10])
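
Each entry yielded here is an nn.parameter.Parameter, a Tensor subclass that Module registers automatically. A minimal sketch (not in the original post) of inspecting a single parameter of the net defined above:

weight_0=dict(net.named_parameters())['0.weight']   # pick out one parameter
print(type(weight_0))   # <class 'torch.nn.parameter.Parameter'>
print(weight_0.grad)    # None until a backward pass has been run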

Initializing Model Parameters

PyTorch's nn.Module layers already use reasonably sensible default initialization strategies, so re-initialization is usually unnecessary. When it is needed, PyTorch's init module provides a variety of preset initialization methods.

# An example
import torch.nn.init as init
for name,param in net.named_parameters():
    # weights: draw from a normal distribution
    if 'weight' in name:
        init.normal_(param,mean=0,std=0.01)
        print(name,param.data)
    # biases: set to a constant
    if 'bias' in name:
        init.constant_(param,val=0)
        print(name,param.data)
0.weight tensor([[-7.3670e-03, -4.5059e-03,  2.5534e-03,  ..., -1.4920e-02,
          2.9664e-02,  9.6315e-03],
        [ 3.3703e-03, -1.7920e-02, -1.4533e-02,  ..., -7.4964e-04,
         -1.1880e-02,  6.7410e-03],
        [-1.5851e-02, -7.5730e-04, -1.4035e-03,  ..., -1.1753e-02,
         -8.3762e-03,  6.4255e-03],
        ...,
        [-8.0277e-04,  2.1240e-02,  6.7807e-03,  ...,  6.3259e-06,
         -1.1621e-02,  4.4899e-03],
        [-4.8244e-03, -7.8703e-03,  6.5461e-03,  ..., -3.9733e-03,
         -8.0912e-04, -8.6507e-03],
        [-1.6352e-02,  2.9723e-03,  2.4309e-03,  ..., -1.0128e-02,
          3.5137e-03, -1.3541e-02]])
0.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
2.weight tensor([[ 5.3187e-04,  1.5396e-02, -9.6198e-03,  ...,  1.9582e-02,
          1.2500e-02,  8.8209e-05],
        [-7.1357e-03, -3.1403e-03,  5.3357e-03,  ...,  6.8758e-03,
         -6.4421e-03,  5.9430e-03],
        [ 2.2387e-02,  5.2646e-03, -3.2153e-03,  ...,  1.7143e-02,
          1.2387e-02, -2.1117e-02],
        ...,
        [-6.0646e-03, -7.9573e-03, -2.3083e-03,  ..., -2.7555e-03,
          5.7004e-04, -5.2114e-03],
        [-1.9535e-02, -9.3469e-03,  2.1729e-02,  ...,  2.7975e-02,
          7.1541e-04,  2.3038e-03],
        [-1.9913e-02, -1.9769e-03,  1.0276e-02,  ...,  7.5271e-03,
         -1.1302e-02,  5.0707e-03]])
2.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

Custom Initialization Methods

Let's first look at the actual implementation of torch.nn.init.normal_. Note that it modifies the tensor in place (hence the trailing underscore) inside a torch.no_grad() block, so the change is not tracked by autograd; a custom initializer should follow the same pattern.

def normal_(tensor,mean=0,std=1):
    with torch.no_grad():
        return tensor.normal_(mean,std)
# Initialize each weight to 0 with probability 1/2, and otherwise draw it
# uniformly from [-10,-5] or [5,10]
def init_weight_(tensor):
    with torch.no_grad():
        tensor.uniform_(-10,10)
        tensor*=(tensor.abs()>=5).float()
for name,param in net.named_parameters():
# re-initialize weights only
    if 'weight' in name:
        init_weight_(param)
        print(name,param.data)
0.weight tensor([[-0.0000,  0.0000, -0.0000,  ..., -9.5116, -0.0000,  0.0000],
        [-5.8680,  0.0000,  0.0000,  ...,  0.0000, -0.0000,  9.0683],
        [ 9.6953, -0.0000,  0.0000,  ..., -5.2760, -0.0000, -8.7688],
        ...,
        [ 0.0000,  6.1618, -0.0000,  ..., -8.2309, -5.0961,  6.9231],
        [ 7.7732, -0.0000, -7.7075,  ...,  0.0000, -5.9438, -9.6617],
        [ 8.5203, -0.0000,  6.8366,  ..., -0.0000, -6.4314, -7.1154]])
2.weight tensor([[-0.0000, -6.0745,  6.4782,  ...,  9.5190, -0.0000,  0.0000],
        [ 0.0000, -7.1742, -0.0000,  ..., -0.0000,  7.9957, -0.0000],
        [-8.2802, -0.0000,  6.7018,  ..., -0.0000,  0.0000, -0.0000],
        ...,
        [-0.0000,  0.0000, -8.1016,  ..., -0.0000,  7.2197, -8.4890],
        [ 0.0000, -9.3670, -0.0000,  ...,  6.4142,  8.5989,  9.9789],
        [ 0.0000,  0.0000,  9.7598,  ..., -0.0000,  6.3064, -5.4405]])
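
As a related aside (a sketch, not part of the original example): parameter values can also be rewritten through param.data, which changes the stored values without affecting gradient tracking.

for name,param in net.named_parameters():
    if 'bias' in name:
        param.data+=1    # shift the values in place; autograd is not involved
        print(name,param.data)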

Sharing Parameters

Parameters are shared when the same layer object is used more than once, either by calling it repeatedly inside a Module's forward function or, as below, by passing the same instance to Sequential multiple times:

# Example: the same Linear instance appears twice
linear=nn.Linear(1,1,bias=False)
net=nn.Sequential(linear,linear)
print(net)
Sequential(
  (0): Linear(in_features=1, out_features=1, bias=False)
  (1): Linear(in_features=1, out_features=1, bias=False)
)

Since both entries are the same Linear object, the first and second layers share parameters.
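
A quick check makes the sharing explicit (a minimal sketch using the linear and net defined above): both entries hold the very same tensor, and because the layer runs twice in the forward pass, gradients from both uses accumulate into it.

print(id(net[0].weight)==id(net[1].weight))   # True: a single shared tensor

nn.init.constant_(net[0].weight,val=3)
x=torch.ones(1,1)
y=net(x).sum()      # y = w*(w*x) = w^2, since x = 1
y.backward()
print(net[0].weight.grad)   # tensor([[6.]]): d(w^2)/dw = 2w = 6 at w = 3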

Saving and Loading

(1) Saving and Loading Tensors

x=torch.ones(3)
torch.save(x,'x.pt')    # save a tensor
x1=torch.load('x.pt')   # load it back
print(x1)

y=torch.zeros(4)
torch.save([x,y],'xy.pt')      # save a list of tensors
xy_list=torch.load('xy.pt')    # and read it back into memory
print(xy_list)

torch.save({'x':x,'y':y},'xy_dict.pt')   # a dict of tensors also works
xy=torch.load('xy_dict.pt')
print(xy)
tensor([1., 1., 1.])
[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]
{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}

Summary: torch.save can save objects of many kinds, including models, tensors, and dictionaries. torch.load uses pickle's unpickling facilities to deserialize the pickled object file back into memory.

(2) Saving and Loading Models

Method 1: via state_dict (recommended). A state_dict is a dictionary object that maps each parameter name to its parameter Tensor. Only layers with learnable parameters have entries in it, and optimizers have a state_dict of their own.

# An example using the MLP class built earlier
net=MLP()
net.state_dict()
OrderedDict([('hidden.weight',
              tensor([[-0.0136, -0.0110,  0.0297,  ...,  0.0014, -0.0208, -0.0321],
                      [-0.0057, -0.0211,  0.0135,  ...,  0.0199, -0.0014,  0.0096],
                      [ 0.0280,  0.0117, -0.0006,  ...,  0.0331, -0.0080, -0.0237],
                      ...,
                      [-0.0246, -0.0046, -0.0349,  ...,  0.0085,  0.0017,  0.0251],
                      [-0.0058, -0.0061, -0.0205,  ..., -0.0287,  0.0247,  0.0005],
                      [-0.0143,  0.0321,  0.0212,  ..., -0.0084,  0.0244,  0.0306]])),
             ('hidden.bias',
              tensor([ 3.1305e-02, -6.9119e-03,  2.0941e-02, -2.2305e-02,  2.5125e-02,
                      -1.3433e-02,  3.1011e-02, -3.1288e-03, -3.4262e-02,  4.4943e-03,
                       3.1427e-02,  1.9461e-02, -3.3079e-02,  2.2021e-02, -1.4982e-02,
                      -1.7399e-02,  2.1805e-02,  3.1871e-02,  1.2301e-02,  3.0461e-03,
                       3.5593e-03, -8.0371e-03, -2.8687e-02, -2.5229e-02,  9.3432e-03,
                      -2.4308e-02, -1.0408e-02, -2.7482e-02,  8.1858e-03, -2.6377e-02,
                      -7.5215e-03, -8.7385e-03,  1.5541e-02,  3.3680e-02,  2.5204e-02,
                       1.5183e-02,  2.7199e-02, -1.4679e-02, -3.2512e-02, -3.2054e-02,
                      -2.8799e-02, -1.8439e-02, -1.1006e-02,  3.0108e-02,  2.2955e-02,
                      -5.9173e-03, -1.4669e-02,  4.8523e-03,  2.0565e-02, -2.2397e-02,
                       2.7658e-02,  6.3019e-03, -5.1514e-03, -7.8413e-03, -1.8985e-02,
                      -7.1641e-04, -5.4482e-03, -1.0629e-02, -1.4976e-02,  2.2032e-02,
                      -2.0785e-02,  2.8934e-02,  6.6899e-03,  3.4973e-02,  1.8431e-02,
                      -1.9760e-02,  2.1899e-02, -9.4859e-03,  3.1312e-03, -1.2290e-02,
                      -2.8553e-02, -1.7020e-02, -1.7534e-02,  3.0155e-02, -5.1623e-03,
                       2.6113e-03,  1.1324e-02,  2.6100e-02, -1.2954e-02, -1.0702e-02,
                      -6.4077e-03, -3.4773e-02, -4.5309e-03, -2.4982e-02, -3.3974e-02,
                      -3.0895e-03,  2.8445e-02,  1.6657e-02, -3.3164e-02, -9.9130e-03,
                      -3.0933e-02, -1.5069e-02,  3.0788e-02, -3.0066e-02,  1.8212e-02,
                       2.6435e-02, -2.5778e-02,  1.2167e-02,  2.1587e-02, -1.9604e-02,
                       1.0127e-02,  1.1425e-02, -1.1706e-02, -2.5564e-02, -3.4525e-02,
                      -2.4434e-02,  2.2406e-02, -3.1730e-02,  3.1080e-02,  1.3585e-03,
                       2.0328e-02, -2.7836e-02, -9.0502e-03,  2.8293e-02,  5.3455e-04,
                      -2.6237e-02,  2.2733e-02, -2.7324e-02,  2.7153e-02,  3.3261e-02,
                       3.2525e-02,  2.3496e-02, -1.2033e-02, -9.2284e-03, -2.3265e-02,
                      -6.9208e-04,  2.9772e-02,  2.4725e-02, -1.0888e-02, -5.4415e-03,
                       1.3602e-02, -5.0255e-03, -8.1237e-03,  1.8826e-02,  1.2347e-02,
                      -1.7835e-02,  5.7540e-03, -8.7569e-03,  3.5205e-02, -3.3670e-02,
                      -3.9439e-03, -2.2330e-02, -7.5736e-03, -5.0021e-03,  2.8512e-02,
                       2.9699e-03, -2.0811e-02,  3.0332e-02,  3.2435e-03, -1.1954e-02,
                      -2.4728e-02,  1.0801e-02, -1.6735e-02,  1.7002e-02,  1.7940e-02,
                       3.8781e-03, -2.1850e-02, -2.1234e-02,  2.0817e-02,  9.5619e-03,
                       2.4271e-02,  2.5589e-05,  1.7797e-02, -1.2524e-02,  6.4133e-03,
                      -6.3742e-03, -6.8082e-03,  4.8706e-03,  1.0436e-02, -3.4496e-02,
                      -2.0085e-02,  3.0331e-02,  3.4712e-02, -1.1799e-02,  7.4912e-03,
                      -2.0683e-02, -4.9508e-03, -9.4079e-03,  1.9821e-02,  1.2626e-02,
                       3.3640e-02, -7.0766e-03, -1.6350e-02,  2.7539e-02, -3.5491e-03,
                      -3.2590e-02, -1.9969e-02,  3.1202e-02, -3.4139e-02,  2.4393e-02,
                       1.4785e-02, -6.9166e-03,  3.1012e-02,  2.2647e-02,  2.6756e-02,
                       6.0994e-03,  2.0769e-02, -8.6712e-03, -2.8026e-02, -5.1754e-03,
                      -5.6501e-03,  2.0725e-03,  4.0739e-03, -9.4127e-03,  2.4582e-02,
                       1.5492e-02, -1.4416e-02, -2.5642e-02, -8.5237e-03, -9.0769e-04,
                       1.7156e-02,  2.5980e-02, -6.8641e-04,  3.5932e-03,  1.2633e-02,
                      -2.8234e-02, -2.8589e-02, -4.4682e-03, -3.0578e-02,  1.5988e-02,
                      -9.0113e-03,  7.1784e-03,  1.2726e-02, -2.0503e-02,  2.5928e-02,
                       8.7758e-03, -3.5033e-02,  1.8163e-02, -2.9047e-02, -1.8063e-02,
                       2.5760e-02,  6.9178e-03, -2.3813e-03, -1.6222e-02, -1.9902e-02,
                       6.9884e-03,  2.9882e-02,  2.0432e-02,  3.2332e-02, -2.1999e-02,
                       1.0688e-02, -2.9595e-02,  7.1982e-03, -2.0782e-02, -6.7366e-03,
                       9.8019e-03, -1.1958e-02, -2.6838e-02, -2.5685e-02, -2.0283e-03,
                      -1.0448e-02, -2.2837e-02, -2.5476e-02, -8.7900e-03,  1.9235e-02,
                      -1.1392e-02])),
             ('output.weight',
              tensor([[ 3.8825e-02,  1.8576e-02, -4.0230e-02,  ...,  2.9058e-02,
                        3.0307e-02,  3.8201e-02],
                      [ 4.5807e-02, -2.3931e-02,  2.2467e-02,  ...,  1.6682e-02,
                        3.3030e-02, -5.9647e-02],
                      [-3.0376e-05,  5.0309e-03,  3.1607e-02,  ...,  4.8286e-02,
                        4.3415e-02, -2.5420e-02],
                      ...,
                      [ 1.2002e-02, -4.3146e-02, -4.4974e-03,  ..., -2.2281e-02,
                       -4.0773e-02, -2.6248e-04],
                      [-5.5008e-02, -3.8148e-02, -3.4065e-02,  ...,  3.3826e-02,
                       -2.6733e-02,  1.8512e-02],
                      [ 9.4306e-03, -3.9237e-02, -5.6302e-02,  ..., -5.6227e-02,
                       -3.6273e-02,  2.9709e-02]])),
             ('output.bias',
              tensor([ 0.0620, -0.0539, -0.0383, -0.0328, -0.0176,  0.0069,  0.0100,  0.0595,
                      -0.0195,  0.0577]))])
# inspect the optimizer's state_dict
optimizer=torch.optim.SGD(net.parameters(),lr=0.001)
optimizer.state_dict()
{'param_groups': [{'dampening': 0,
   'lr': 0.001,
   'momentum': 0,
   'nesterov': False,
   'params': [499798207584, 499798206792, 499798380120, 499798379112],
   'weight_decay': 0}],
 'state': {}}
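
Because the optimizer keeps a state_dict of its own, a common checkpointing pattern (a sketch, not from the original post; 'checkpoint.pt' is a hypothetical file name) saves both dictionaries together so that training can be resumed later:

checkpoint={
    'model_state_dict':net.state_dict(),
    'optimizer_state_dict':optimizer.state_dict(),
}
torch.save(checkpoint,'checkpoint.pt')   # save both in one file

# to resume later:
checkpoint=torch.load('checkpoint.pt')
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])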

Saving and Loading the state_dict (the Recommended Approach)

X=torch.rand(2,784)
Y=net(X)
PATH="C:/Users/mingming/net.pt"
torch.save(net.state_dict(),PATH)   # save the parameters

net2=MLP()
net2.load_state_dict(torch.load(PATH))   # load them into a fresh model
Y2=net2(X)
Y2==Y        # check that the two models produce identical outputs
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

Saving and Loading the Entire Model

torch.save(net2,PATH)
model=torch.load(PATH)
print(model)
D:\Anaconda3\lib\site-packages\torch\serialization.py:360: UserWarning: Couldn't retrieve source code for container of type MLP. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "


MLP(
  (hidden): Linear(in_features=784, out_features=256, bias=True)
  (act): ReLU()
  (output): Linear(in_features=256, out_features=10, bias=True)
)
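
The warning above hints at the drawback of this approach: torch.save(net2,PATH) pickles the module object itself, so the MLP class definition must be importable when torch.load is called, and its source cannot be checked for consistency at load time. That fragility is why saving the state_dict is the recommended method.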