文章目录
一、定义网络(Net)
用一个名为Net的类定义
需要继承torch.nn中的nn.Module(注意M大写)
Net类包括初始化函数和forward函数两部分
1)初始化
init_(self): 放置有可学习参数的层(注意init前后均是两个下划线)
a)对nn.Module初始化: super(Net, self)init()
b)定义卷积和全连接操作(用到nn.Conv2d(), nn.Linear())
2)前向操作
forward(self, x)
输入x,按照网络前向传播步骤,调用初始化中定义的卷积和全连接操作,得到最后输出,并return。
如下简单定义一个cnn模型:
class SVHN_Model1(nn.Module):
#初始化
def __init__(self):
super(SVHN_Model1,self).__init__()
##CNN提取模块
self.cnn=nn.Sequential(
nn.Conv2d(3,16,kernel_size=(3,3),stride=(2,2)),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(16,32,kernel_size=(3,3),stride=(