一、网络模型的创建步骤
构建子模块就相当于卷积、激活、池化、全连接等操作,拼接子模块就相当于将所有子模块徐合成一个网络结构。在初始化网络结构时,加载的是__init__中进行的卷积、激活、池化、全连接等操作,将图像张量传入网络结构时进行的是forward()中的拼接好的网络结构
二、nn.Module属性
所有网络结构都需要继承nn.Module模块
在我们自己创建网络结构时需要进行导入nn模块,代码为
import torch.nn as nn
import torch.nn.functional as F
class Mynet(nn.Module):
def __init__(self, n_classes=10): # 识别多少类n_classes=10
...
def __forward__(self, x):
...
return x
下面介绍一下torch.nn中的比较重要的参数
1、nn.Parameter
张量子类,表示可学习参数,如weight、bias
2、nn.functional
函数具体实现,如卷积、激活函数、池化等
3、nn.init
提供参数初始化的方法
4、nn.Module
所有网络结构的基类,管理网络属性
nn.Module的参数:
parameters:存储管理nn.Parameter类
modules:存储管理nn.Modules类
buffers:存储管理缓冲属性,如BN层中的running_mean
***_hooks:存储管理钩子函数
总结:nn.Module是每个网络结构必须继承的类,一个网络module包含很多个小的module,一个module相当于一个小的运算,每个module都有8个字典来管理它的属性(上面图片中的属性)。