004.Module源码学习

004.Module源码学习

0.总序

在神经网络的搭建中,第一步往往都是继承nn.module的init方法,那么做为父类的Module中自然有很多的方法值得我们去研究,这对我们更好的使用神经网络打下了基础。而了解一个成熟体系的方法最直观的就是直接看它的源代码。所以今天就带着大家来欣赏(bushi)一下Module类里面经常用的一些方法

1.Module的树状结构

在一个现成的神经网络中,我们总是先init一些可学习参数(Parameters)后再进行前向(forward)过程,而在init的过程中,我们往往需要将数据先经过a处理再经过b处理再经过c处理等等,如下

def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 64, 3, 1) 
    self.dropout1 = nn.Dropout(0.25)
    self.dropout2 = nn.Dropout(0.5)
    self.fc1 = nn.Linear(9216, 128)
    self.fc2 = nn.Linear(128, 10)
    self.flatten = nn.Flatten(1)
    self.relu = nn.ReLU()

这时候上下之间实际上是一种嵌套的过程,即他们之间是一层套一层的,关于这部分,我们可以看官方文档的解释

r"""Base class for all neural network modules.

Your models should also subclass this class.
#wjs:这里表示了递归定义的本质
Modules can also contain other Modules, allowing to nest them in
a tree structure. You can assign the submodules as regular attributes::

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their
parameters converted too when you call :meth:`to`, etc.

Modules can also contain other Modules, allowing to nest them ina tree structure(模型同样可以继承其它模型,就像树一样)

具体的方法可以通过 **list(net.named_children())**来查看具体的嵌套的子类型。源代码如下

def named_children(self) -> Iterator[Tuple[str, 'Module']]:
    memo = set()
    for name, module in self._modules.items():
        if module is not None and module not in memo:
            memo.add(module)
            yield name, module

可以看到反而这种访问子layer这个操作并不是通过递归来实现的,而是在init里面就已经初始化了一个 **ordered_dict[]**来存这些子模型。

为了更直观的理解这个所谓(树结构),我借鉴了一下 数据结构(bushi)里面很经典的二叉树的定义(具体分为什么前序中序后序,但这里没必要解释的这么详细)

#include<iostream>
//abc##de#g##f###

using namespace std;
const int N=1e5;
typedef struct bnode{
	char data;
	struct bnode *left,*right;//嵌套定义定义左右指针
}bnode,*bitree;
typedef struct {
	char *top;
	char *base;
	int size;
	}stack;
void creat(bitree &T)
{	char c;
	cin>>c;
	if(c=='#'){
		T=NULL;
		return; 
	}
	else
	{
	T=new bnode;
	creat(T->left);
	creat(T->right);
	T->data=c;
	}
}

void travel (bitree &t){
	if(t==NULL)return ;
	else{
		
		travel(t->left);
		cout<<t->data<<endl;
		travel(t->right);
		}
	}
int main(){
	bitree t;
	creat(t);
	travel(t);
}

可以发现,树状结构本质上还是递归调用,一层嵌套一层,所以我们在 遍历整体的每一个模型的时候,就要去通过递归(dfs)不断地去访问,才能到达每一层。

2.Parameter和Buffer

我们在模型初始化,有没有想过,为啥要将init和forward两个方法 分开写

def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 64, 3, 1) 
    self.dropout1 = nn.Dropout(0.25)
    self.dropout2 = nn.Dropout(0.5)
    self.fc1 = nn.Linear(9216, 128)
    self.fc2 = nn.Linear(128, 10)
    self.flatten = nn.Flatten(1)
    self.relu = nn.ReLU()
def forward(self, x):
    x = self.conv1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.relu(x)
    x = F.max_pool2d(x, 2)
    x = self.dropout1(x)
    x = torch.flatten(x, 1)
    x = self.fc1(x)
    x = F.relu(x)
    x = self.dropout2(x)
    x = self.fc2(x)
    output = F.log_softmax(x, dim=1)
    return output 

原因是:init里面实例化的参数只生成了一次,而forward里面的东西一直都要流动的去跑.即有可学习参数的模块必须放到init里面,否则将每次训练就会清零。

(我理解为为init就好像一个教室,而forward就好像学生和老师,我们在上课的时候,教室本身是不会变的,但老师和学生每天都要去不同的教室。)

从而引出来了模型训练中两种不同的数据类型

第一种是parameter类型,这种数据类型是不断需要优化的参数(可学习参数)的集合,比如weight和bias。我们在具体查看的时候可以通过具体查看某一层的parameter的方法来查看 net().layer._parameters 的方法来查看,当我们想查看所有的parameter的时候,我们可以通过 net().state_dict() 的方法来查看所有的parameter

第二种就是buffer类型,这种数据类型即在计算中间所需要用到的参数类型,并不参与整体的优化过程,比如说一些mean(平均值),var(方差),num_batches_tracked(跟踪在训练过程中已经处理的批次数量)。和上面一样,我们同样可以用 net().layer._buffers的方法来查看。

def parameters_demo():
    a = Net3()
    b = Net2()
    c = list(b.named_parameters())
    # c = list(b.named_children())
    d = b._modules
    dd = b._parameters
    f = b.conv1._parameters
    stat_dict = a.state_dict()
    output = b(torch.randn(4, 1, 28, 28))
    # print(stat_dict)
        
    print(c)
    print("======================")
    print(f)
    print("======================")
    #print(stat_dict)


def buffers_demo():
    b = Net2()
    c = list(b.named_buffers())
    d = list(b.named_children())
    f = b.conv1.buffers()
    stat_dict = b.state_dict()
    #output = b(torch.randn(4, 1, 28, 28))
    print(d)
    print("======================")
    print(c)
    print("======================")
    print(stat_dict)

运行结果如下


[('conv1', Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))), ('relu1', SELU()), ('bn1', BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), ('bn3', BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), ('conv2', Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))), ('bn2', BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), ('fc1', Linear(in_features=21632, out_features=128, bias=True)), ('fc2', Linear(in_features=128, out_features=10, bias=True))]
======================
[('bn1.running_mean', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])), ('bn1.running_var', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), ('bn1.num_batches_tracked', tensor(0)), ('bn3.running_mean', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])), ('bn3.running_var', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), ('bn3.num_batches_tracked', tensor(0)), ('bn2.running_mean', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])), ('bn2.running_var', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.])), ('bn2.num_batches_tracked', tensor(0))]
======================
OrderedDict([('conv1.weight', tensor([[[[-0.1754, -0.0754,  0.0349, -0.0913,  0.0590],
          [ 0.1680, -0.1742,  0.0441, -0.1379,  0.0351],
          [-0.0803, -0.0672, -0.0796,  0.0693,  0.1558],
          [ 0.0459,  0.1850, -0.1862, -0.0813,  0.1675],
          [ 0.1175,  0.0641,  0.1137,  0.1618, -0.1781]]],


        [[[-0.0515, -0.1793,  0.0048, -0.0938, -0.1675],
          [ 0.0583,  0.1005, -0.0249, -0.0309, -0.0515],
          [ 0.0557, -0.1874,  0.1559,  0.0226,  0.1169],
          [ 0.0350,  0.1430, -0.1590, -0.1315, -0.0486],
          [ 0.0041, -0.1663,  0.0030, -0.1914,  0.0710]]],


        [[[-0.0461,  0.0359,  0.1834,  0.0892, -0.1025],
          [-0.0571, -0.0554,  0.0580,  0.0522,  0.0716],
          [ 0.0409,  0.0171,  0.0631,  0.1283,  0.1672],
          [-0.1058, -0.1598, -0.0087, -0.1819, -0.1446],
          [-0.0373,  0.0653,  0.0240, -0.1932,  0.0188]]],


        ...,


        [[[ 0.1763, -0.1416,  0.0120,  0.1276, -0.0239],
          [ 0.1351,  0.0474, -0.0038, -0.0564, -0.0113],
          [ 0.1957,  0.1413, -0.0678,  0.1587,  0.0729],
          [ 0.1074,  0.1266,  0.0871,  0.0886,  0.0994],
          [-0.0419, -0.0109, -0.0668,  0.0193, -0.0880]]],


        [[[ 0.1076,  0.0245,  0.1775,  0.0457, -0.0145],
          [-0.0630, -0.1941, -0.1122,  0.0476, -0.1586],
          [ 0.0600,  0.0179, -0.0333,  0.0549,  0.0228],
          [ 0.0780, -0.0122,  0.1359,  0.1737, -0.1565],
          [-0.1449,  0.0645, -0.0525, -0.1279,  0.1719]]],


        [[[-0.1401, -0.1280, -0.1858, -0.1158, -0.0588],
          [ 0.0472,  0.0838,  0.0262, -0.1322, -0.0763],
          [-0.0145, -0.1838,  0.0691,  0.1344,  0.0162],
          [-0.0716,  0.1610, -0.0075,  0.1444, -0.0991],
          [ 0.0290, -0.1872, -0.0896,  0.0616,  0.1906]]]])), ('conv1.bias', tensor([ 0.1150,  0.1484,  0.0720, -0.0048,  0.0364,  0.1806,  0.0221,  0.1796,
        -0.1551,  0.0646,  0.0775,  0.1216,  0.1040, -0.1490,  0.0326,  0.0402,
         0.0575,  0.1050,  0.0388,  0.1828,  0.1802,  0.1670, -0.0784,  0.0902,
        -0.0326,  0.0478, -0.1534, -0.1012,  0.0066,  0.0712,  0.0815,  0.1742,
         0.1442,  0.1291,  0.0320, -0.0006, -0.1638,  0.1013,  0.1527, -0.1625,
         0.0796,  0.0247,  0.1280, -0.0027,  0.1191, -0.0581, -0.0853, -0.1932,
         0.1886,  0.0957, -0.0298,  0.1406, -0.0762, -0.1394, -0.0891, -0.1856,
        -0.1415, -0.0177, -0.0676, -0.1127,  0.1902, -0.0872, -0.0990,  0.1869])), ('bn1.weight', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), ('bn1.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])), ('bn1.running_mean', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])), ('bn1.running_var', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), ('bn1.num_batches_tracked', tensor(0)), ('bn3.weight', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), ('bn3.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])), ('bn3.running_mean', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])), ('bn3.running_var', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])), ('bn3.num_batches_tracked', tensor(0)), ('conv2.weight', tensor([[[[ 1.3804e-02,  2.6726e-03,  2.7326e-02],
          [ 3.2897e-03,  4.1362e-02, -2.4802e-02],
          [ 1.8567e-02, -2.8362e-02, -2.5149e-02]],

         [[ 1.9172e-03, -1.2559e-02, -1.9507e-02],
          [-6.1990e-03, -3.8953e-02,  2.7368e-02],
          [-2.6011e-03, -1.0121e-02,  8.6531e-03]],

         [[ 2.2412e-02, -2.7161e-02, -4.4404e-03],
          [ 3.9468e-02, -1.6361e-02, -1.3781e-02],
          [ 2.3425e-02,  4.0472e-02, -1.7511e-03]],

         ...,

         [[ 4.1043e-02, -1.0555e-02, -2.6990e-02],
          [ 2.6584e-02, -2.6117e-03,  1.1389e-02],
          [-2.0494e-04,  3.9066e-02, -3.3418e-02]],

         [[-1.0516e-03, -4.1187e-02,  1.8196e-02],
          [ 8.1509e-03, -1.8520e-02,  2.8867e-02],
          [-3.2531e-02, -2.8833e-02, -3.8639e-02]],

         [[ 3.1313e-02,  2.8149e-02, -3.1794e-02],
          [-3.7953e-02, -4.2451e-03, -1.3679e-02],
          [-2.5701e-05, -1.4094e-02, -1.6580e-02]]],


        [[[ 1.7749e-02,  2.4568e-03,  2.9429e-02],
          [-2.4494e-02,  3.2476e-02,  3.2034e-02],
          [ 1.9105e-02, -5.0577e-03, -1.0027e-02]],

         [[-1.6145e-02,  4.0818e-02,  1.5836e-02],
          [-4.9933e-03,  3.9374e-02,  1.4519e-02],
          [ 4.1226e-02, -3.1400e-02,  3.3854e-02]],

         [[ 1.2728e-02, -9.5448e-03, -1.5378e-02],
          [-3.8991e-02,  9.2428e-03, -9.1718e-03],
          [ 3.2750e-02, -1.8908e-02,  1.1480e-02]],

         ...,

         [[ 3.5675e-02,  5.9206e-03, -2.0005e-02],
          [-3.8375e-02, -1.9865e-02,  2.2279e-03],
          [ 2.7246e-03, -1.2991e-02, -1.8970e-03]],

         [[-3.5456e-02,  1.1164e-02, -2.0951e-02],
          [ 8.9100e-04,  2.8497e-02, -3.6699e-02],
          [ 3.6944e-02,  1.8426e-02, -1.8736e-02]],

         [[ 4.4348e-04, -3.0166e-02,  1.5241e-02],
          [ 3.9226e-02, -6.5979e-03,  2.7491e-02],
          [-3.9068e-02,  3.7548e-02, -2.7914e-02]]],


        [[[ 1.3020e-02,  2.1492e-03, -4.0598e-04],
          [-3.0876e-02,  3.6492e-02,  8.0021e-03],
          [-3.0781e-02,  2.9084e-02,  3.4763e-02]],

         [[-3.4699e-02, -4.1511e-02, -3.6165e-02],
          [ 5.6653e-03, -9.0752e-05, -1.7392e-02],
          [ 2.5487e-03, -2.7225e-03, -3.8598e-02]],

         [[-1.6121e-02, -2.4392e-02,  1.9272e-02],
          [ 3.6324e-02, -1.4379e-02,  3.6113e-02],
          [-1.6121e-02, -1.1428e-02, -3.5246e-02]],

         ...,

         [[-2.1558e-02, -3.0543e-02,  4.0315e-02],
          [ 3.4141e-02, -2.4406e-02,  2.3893e-02],
          [-3.2582e-02,  1.9556e-03, -1.4825e-02]],

         [[ 4.0184e-02, -2.8116e-02, -1.5306e-02],
          [-2.8341e-02,  1.0730e-02, -1.9845e-03],
          [-3.4924e-03,  1.5321e-02, -3.5831e-02]],

         [[ 3.1357e-02, -1.1074e-02, -1.6301e-02],
          [-4.2827e-03,  1.7536e-02,  4.0750e-02],
          [-3.0957e-02, -2.8302e-03, -2.1691e-02]]],


        ...,


        [[[ 2.8327e-03,  2.5519e-02, -2.0793e-04],
          [ 1.7179e-02, -2.3834e-02, -1.0083e-02],
          [-1.1535e-02, -2.7301e-03,  5.7620e-04]],

         [[-2.6855e-02,  2.9829e-02, -3.1412e-02],
          [-1.0995e-02,  4.0944e-02,  1.9222e-02],
          [-3.5822e-02, -1.9807e-02, -1.9813e-02]],

         [[-3.5791e-02, -2.3066e-02,  3.5855e-03],
          [-3.7625e-02,  3.0001e-02,  1.7222e-02],
          [ 3.7793e-03,  6.5467e-03,  2.0791e-02]],

         ...,

         [[-1.4132e-02, -1.2298e-02, -3.3130e-03],
          [ 2.2667e-02,  3.9222e-02,  3.4434e-03],
          [ 4.1459e-02,  3.6427e-02,  2.2059e-02]],

         [[ 3.2449e-02, -3.1477e-02, -1.2655e-03],
          [-2.0479e-02, -1.1059e-02,  1.9181e-03],
          [-1.4594e-02, -1.0662e-03, -3.2223e-02]],

         [[-2.3213e-02, -2.5523e-02, -3.8667e-02],
          [-3.7524e-02, -1.9326e-02, -3.1330e-02],
          [ 1.7482e-02, -2.6733e-02, -3.1149e-02]]],


        [[[ 1.3412e-02, -1.2247e-02,  3.4369e-02],
          [ 2.7067e-02,  1.4548e-02, -1.7364e-02],
          [-5.8489e-03, -3.0450e-02, -2.4052e-02]],

         [[ 7.5185e-04,  3.5632e-02,  3.3540e-02],
          [-3.8581e-02, -3.5501e-02,  1.3043e-02],
          [-4.1002e-02, -1.5784e-02,  8.6450e-03]],

         [[-1.4037e-02,  4.0806e-02,  7.5803e-04],
          [ 3.3773e-03, -3.4202e-02, -3.7942e-02],
          [ 3.6138e-02,  1.8666e-02, -4.0556e-02]],

         ...,

         [[-1.5326e-02, -2.8451e-02, -2.9994e-02],
          [-9.7842e-03,  1.8953e-02,  1.3231e-02],
          [ 1.8628e-02,  1.9107e-02, -2.9562e-02]],

         [[-1.6541e-02, -2.7337e-02,  2.5576e-03],
          [ 3.9585e-02, -1.1785e-02, -1.4699e-02],
          [ 1.1857e-02,  2.0325e-02,  3.2569e-04]],

         [[ 1.6132e-02,  4.1542e-02, -2.6733e-02],
          [-2.2193e-02, -7.7598e-03,  3.0271e-02],
          [ 2.1045e-02, -2.1218e-02,  3.8681e-02]]],


        [[[ 1.5345e-02, -3.6382e-02,  7.7098e-03],
          [-6.3445e-03, -3.6145e-02,  2.4919e-02],
          [ 1.9101e-03,  4.3644e-03, -3.4439e-02]],

         [[ 1.0588e-02,  3.8859e-02,  1.7641e-02],
          [-4.0048e-02,  3.4477e-02,  2.2445e-02],
          [-9.7489e-03,  2.8978e-02,  3.7072e-02]],

         [[-2.6893e-02,  2.5959e-02, -2.2724e-02],
          [ 3.9606e-02, -1.1915e-02, -2.4960e-02],
          [-1.1373e-02, -2.9400e-03, -3.8703e-03]],

         ...,

         [[-2.5881e-02, -1.8536e-02, -3.1195e-02],
          [-8.5401e-04, -1.8786e-02, -1.1179e-02],
          [ 4.0288e-02, -7.3430e-03,  3.6102e-02]],

         [[ 2.4128e-02,  1.2720e-03,  2.3537e-02],
          [ 3.3085e-02,  1.5664e-03, -2.5634e-02],
          [ 1.9059e-02, -2.2312e-02,  6.9598e-03]],

         [[-1.3052e-02,  3.4003e-02, -3.9267e-02],
          [ 4.1218e-02, -3.1411e-03, -1.3546e-02],
          [-1.5814e-02, -4.1339e-02,  2.3463e-02]]]])), ('conv2.bias', tensor([ 0.0243,  0.0328, -0.0063,  0.0163, -0.0054, -0.0012, -0.0274, -0.0332,
         0.0172, -0.0152,  0.0381,  0.0099,  0.0282, -0.0174,  0.0068,  0.0181,
         0.0212, -0.0107, -0.0189, -0.0200, -0.0160, -0.0290,  0.0316,  0.0059,
        -0.0104,  0.0409,  0.0031,  0.0217, -0.0117, -0.0142,  0.0187,  0.0359,
         0.0163,  0.0348,  0.0351, -0.0167,  0.0218, -0.0179, -0.0079,  0.0243,
         0.0405, -0.0330,  0.0315, -0.0345, -0.0339, -0.0288,  0.0152, -0.0069,
         0.0262, -0.0338, -0.0161,  0.0414, -0.0065,  0.0235,  0.0256, -0.0314,
        -0.0277, -0.0102,  0.0079, -0.0031, -0.0184,  0.0025,  0.0133,  0.0335,
        -0.0284,  0.0034, -0.0300,  0.0088,  0.0391, -0.0415,  0.0262, -0.0173,
         0.0279,  0.0413,  0.0196,  0.0048,  0.0165,  0.0300,  0.0255, -0.0371,
         0.0385, -0.0302,  0.0162, -0.0170,  0.0199, -0.0034, -0.0105, -0.0158,
         0.0238, -0.0238, -0.0228, -0.0146, -0.0006, -0.0129, -0.0011,  0.0206,
         0.0192, -0.0369, -0.0282, -0.0029, -0.0286, -0.0055,  0.0239,  0.0186,
        -0.0261,  0.0229, -0.0216, -0.0257,  0.0380, -0.0396, -0.0117, -0.0319,
         0.0098,  0.0413, -0.0033,  0.0333, -0.0222,  0.0070,  0.0086,  0.0028,
        -0.0359,  0.0263, -0.0011, -0.0368, -0.0333,  0.0153, -0.0241, -0.0369])), ('bn2.weight', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.])), ('bn2.bias', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])), ('bn2.running_mean', tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])), ('bn2.running_var', tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1.])), ('bn2.num_batches_tracked', tensor(0)), ('fc1.weight', tensor([[ 0.0030,  0.0066,  0.0016,  ...,  0.0057, -0.0011, -0.0002],
        [-0.0067,  0.0027, -0.0005,  ...,  0.0032,  0.0041, -0.0031],
        [ 0.0044,  0.0020,  0.0044,  ..., -0.0038, -0.0053, -0.0035],
        ...,
        [ 0.0062, -0.0024, -0.0062,  ...,  0.0057,  0.0038, -0.0057],
        [-0.0050,  0.0020,  0.0054,  ...,  0.0001, -0.0044,  0.0024],
        [ 0.0002,  0.0006,  0.0049,  ...,  0.0011, -0.0043, -0.0046]])), ('fc1.bias', tensor([-7.3116e-04, -1.3034e-03, -2.8664e-03,  3.9936e-03, -4.2440e-03,
        -4.9437e-03,  4.3636e-03,  3.7657e-03,  3.1575e-03,  3.5778e-04,
         3.4981e-03, -3.3745e-04,  5.9981e-03,  3.5468e-03,  8.0280e-04,
        -1.6093e-04, -2.4811e-03,  5.5249e-03, -4.6965e-03, -2.3584e-04,
         3.6335e-04,  1.5001e-03,  4.5147e-03,  3.8937e-03, -1.3808e-03,
         1.4258e-03, -5.8644e-03,  1.2596e-03, -5.5135e-03,  5.6424e-03,
         6.4816e-03, -4.1891e-04,  6.4315e-04,  1.2446e-03, -3.3439e-03,
        -5.7036e-03,  2.6038e-03, -5.4193e-03,  7.4272e-04,  3.3243e-03,
         2.8245e-03,  4.7096e-03,  5.1189e-03,  4.9899e-03,  3.5888e-03,
         4.8485e-04,  6.0303e-03,  1.9599e-03, -7.5914e-04,  2.0569e-03,
         4.0892e-03, -5.1135e-03,  9.6662e-04,  3.3163e-03,  8.2273e-05,
         4.3252e-03,  8.8444e-04, -1.2848e-03,  1.2711e-03,  4.0505e-03,
        -7.5686e-04, -3.8382e-03,  1.8764e-03,  3.4371e-03,  4.6230e-03,
         1.4412e-03,  1.0875e-03,  6.3353e-03, -4.4848e-03, -3.2108e-03,
        -5.6673e-03, -3.1792e-03, -2.0406e-03, -5.9893e-03, -5.9565e-03,
        -8.3298e-04,  6.5114e-03,  2.4717e-03,  1.2593e-03, -3.8248e-03,
         6.4712e-03, -3.8312e-03, -3.9619e-03, -2.8453e-03,  1.3072e-03,
         6.2060e-03,  6.5183e-03, -6.4530e-03,  5.4113e-03,  5.5932e-03,
         2.6481e-03,  6.6035e-03,  2.6949e-03,  2.7214e-03,  3.5480e-03,
         6.2800e-03,  2.8863e-03,  4.4883e-03, -1.0256e-03, -5.5832e-03,
        -1.9126e-03,  1.2627e-03, -2.9106e-03,  5.4015e-03,  4.1861e-03,
        -3.8516e-03, -5.0588e-03, -3.6308e-03, -4.8302e-03, -3.0415e-03,
         2.4725e-04,  5.1073e-03,  2.3462e-03, -8.5401e-04,  5.5719e-03,
        -6.4838e-03, -6.0859e-03,  5.7448e-03, -9.4883e-04,  4.6147e-03,
        -2.0898e-04, -6.5476e-03,  3.2324e-03, -3.0787e-03,  3.6678e-04,
        -1.5722e-03,  5.1291e-03,  1.9171e-03])), ('fc2.weight', tensor([[ 0.0205,  0.0124, -0.0703,  ..., -0.0794, -0.0599, -0.0617],
        [ 0.0632,  0.0400, -0.0236,  ..., -0.0381,  0.0784, -0.0674],
        [ 0.0012, -0.0802,  0.0272,  ..., -0.0795, -0.0014,  0.0882],
        ...,
        [-0.0591, -0.0537, -0.0850,  ...,  0.0055,  0.0879, -0.0645],
        [ 0.0862,  0.0018, -0.0836,  ..., -0.0659,  0.0293, -0.0454],
        [-0.0539,  0.0060, -0.0491,  ..., -0.0495, -0.0264, -0.0669]])), ('fc2.bias', tensor([ 0.0705,  0.0022, -0.0763, -0.0407, -0.0482, -0.0100, -0.0798,  0.0247,
         0.0049,  0.0236]))])

进程已结束,退出代码为 0

PS:detach方法

这个方法本来不属于module里面定义的,而是tensor里面的,但是在module源代码里面却频繁的去使用,所以我不得不单独拿出来说说,

先看看官方说法

https://pytorch.org/docs/stable/autograd.html#tensor-autograd-functions

torch.Tensor.detach
Returns a new Tensor, detached from the current graph.

torch.Tensor.detach_
Detaches the Tensor from the graph that created it, making it a leaf.

创建一个新的tensor,并从当前的图中去分离,但要注意的是

新的tensor和原来的tensor共用一段内存 ,即修改一个,另外一个也会随之改变,所以.detach之后 若修改了的数据对反向过程的计算有影响的话,则会报错 反而,如果没有影响的话,那么修改了之后也无所谓

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
b=a.sigmoid()
c=b.sum()
d=c.relu()
print(a.grad)
print(a.detach().zero_())
d.backward()
print(b.detach().zero_())
d.backward()
print(c.detach().zero_())
d.backward()
print(d.detach().zero_())
d.backward()

这里分离了a和c之后再进行原地操作的时候是没有任何影响的,只有在分离了b和d之后才会产生影响,因为对于sum函数求导之后还是0.而对于sigmod和relu函数求导之后改变化会对导函数产生影响

Traceback (most recent call last):
  File "C:\ProgramData\Anaconda3\Lib\site-packages\torch\nn\modules\temp01.py", line 11, in <module>
    d.backward()
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

3.模型的保存与加载

  • 动态图与静态图:这是个很抽象的概念,但通过模型加载这个案例可以很好的去理解。大多数情况下,pytorch都是动态图存储 意味着只有当运行的时候,我们在里面写的网络才得以运行,而相反,静态图即是在书写的时候就已经对网络造成了影响。懂了上面这番话,就很好去理解模型的加载机制了。

  • 动态图的保存机制 :动态图的保存之前必须要先实例化网络再去保存,要不然运行的永远是历史保存的数据做不到更新这个操作.

  • 动态图的加载机制:说到底,由于动态图保存的永远只是模型的参数(

    weight,bias等parameter。)而不会去保存它的forward过程,所以在定义的时候,前向过程就显得可有可无(bushi),但我们的forward还是要有的,要不然网络跑不起来(qwq)。

from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def save_demo_v1():
    #wjs:(动态图)保存的只是init中的模块,weight,buffers三块内容,而不是完全
    model = Net()
    input = torch.rand(1, 1, 28, 28)
    output = model(input)
    torch.save(model, "mnist.pt")  # 4.6M : 保存


def load_demo_v1():
    model = torch.load("mnist.pt")
    input = torch.rand(1, 1, 28, 28)
    output = model(input)
    print(f"output shape: {output.shape}")


def save_para_demo():
    #wjs:先实例化对象,保存所有的数据,包括parameters和buffers
    model = Net()
    torch.save(model.state_dict(), "mnist_para.pth")


def load_para_demo():
    param = torch.load("mnist_para.pth")
    model = Net()
    model.load_state_dict(param)
    input = torch.rand(1, 1, 28, 28)
    output = model(input)
    print(f"output shape: {output.shape}")


def tensor_save():
    tensor = torch.ones(5, 5)
    torch.save(tensor, "tensor.t")
    tensor_new = torch.load("tensor.t")
    print(tensor_new)


def load_to_gpu():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load('mnist.pt', map_location=device)
    print(f"model device: {model}")


def save_ckpt_demo():
    model = Net()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    loss = torch.Tensor([0.25])
    epoch = 10
    #wjs:模型中断的时候需要保存的数据
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # 'loss': loss.item(),
        # 可以添加其他训练信息
    }

    torch.save(checkpoint, 'mnist.ckpt')


def load_ckpt_demo():
    checkpoint = torch.load('model.ckpt')
    model = Net()  # 需要事先定义一个net的实例
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    input = torch.rand(1, 1, 28, 28)
    output = model(input)
    print("output shape: ", output.shape)


def save_trace_model():
    #wjs:变成静态图
    model = Net().eval()
    # 通过trace 得到了一个新的model,我们最终保存的是这个新的model
    traced_model = torch.jit.trace(model, torch.randn(1, 1, 28, 28))
    traced_model.save("traced_model.pt")


def load_trace_model():
    mm = torch.jit.load("traced_model.pt")
    output = mm(torch.randn(1, 1, 28, 28))
    print("load model succsessfully !")
    print("output: ", output)


if __name__ == "__main__":
    # save_demo_v1()
    # load_demo_v1()
    # save_para_demo()
    # load_para_demo()
    # tensor_save()
    # load_to_gpu()
    # save_trace_model()
    # load_trace_model()
    # save_ckpt_demo()
    # load_ckpt_demo()
    print("run save_load_demo.py successfully !!!")
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        '''self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)'''
        self.convp=nn.Conv2d(1, 32, 3, 1)

    def forward(self, x):
        '''x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)'''
        return 0
def load_demo_v1():
    model = torch.load("mnist.pt")
    input = torch.rand(1, 1, 28, 28)
    output = model(input)
    print(f"output shape: {output.shape}")

当我们更改了网络的结构之后再去加载,可以看到仍然是保存的是以前的模型数据, 你说的对,可这就是动态图,加载慢,事多,但就因为容易上手好学,编译过程像python就一直用…

print(model)
Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

PS:sequential方法

一个网络的传统写法包括很多,通常包括要先写init再写forward,前面也讲过为什么要分开写,但能不能直接将二者结合一下,即在forward的过程中将网络布置,所以就有了 sequential这个方法。

引用一下~~~~:

9.Sequential的介绍和神经网络搭建实战_sequential神经网络-CSDN博客

这是传统写法

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui,self).__init__()
        self.conv1=Conv2d(3,32,5,padding=2) 
        self.maxpool1=MaxPool2d(2)
        self.conv2=Conv2d(32,32,5,padding=2)
        self.maxpool2=MaxPool2d(2)
        self.conv3=Conv2d(32,64,5,padding=2)
        self.maxpool3=MaxPool2d(2)
        #经过Flatten
        self.flatten=Flatten()
        self.linear1=Linear(1024,64)
        self.linear2=Linear(64,10)
        
    def forward(self,x):
        x=self.conv1(x)
        x=self.maxpool1(x)
        x=self.conv2(x)
        x=self.maxpool2(x)
        x=self.conv3(x)
        x=self.maxpool3(x)
        x=self.flatten(x)
        x=self.linear1(x)
        x=self.linear2(x)
        return x

这是便捷写法

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

可以看见,方便了不少。

但是从初学者的角度来说,还是第一种比较麻烦的写法能够更好的帮助去理解网络的训练过程,即什么是可学习参数,那些是流动的数据那些是固定的数据

但我还是会选第二种,哈哈

  • 16
    点赞
  • 51
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值