神经网络由对数据执行操作的层/模块组成。torch. nn
命名空间提供了构建自己的神经网络所需的所有构建块。PyTorch
中的每个模块都子类化了nn.Module
**。**神经网络是由其他模块(层)组成的模块本身。这种嵌套结构允许轻松构建和管理复杂的架构。
本次我们以CNN神经网络的开山之作LeNet
为例,介绍神经网络的建立。LeNet
是一个 7 层的神经网络,包含 3 个卷积层,2 个池化层,1 个全连接层。其中所有卷积层的所有卷积核kernel
都为 5x5,步长 stride
=1,池化方法都为全局 pooling
,激活函数为 Sigmoid
,网络结构如下:
模型的PyTorch
实现如下:
import torch
import torch.nn as nn
class LeNet(nn.Module): # 创建神经网络都需要继承 nn.Module
def __init__(self, intput_channle=1, output_channle=10): # 输入通道为1, shu'chu'ton
super().__init__()
self.conv1 = nn.Conv2d(intput_channle, 6, kernel_size=5, padding=2) # 二维卷积层
self.act1 = nn.Sigmoid() # Sigmoid()激活层
self.pool1 = nn.AvgPool2d(2) # 二维平均池化层
self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
self.act2 = nn.Sigmoid()
self.pool2 = nn.AvgPool2d(2)
self.flatten = nn.Flatten() # 将张量进行展平,但保持第一维度(dim=0)
self.linear1 = nn.Linear(16 * 5 * 5, 120) # 线性层
self.act3 = nn.Sigmoid()
self.linear2 = nn.Linear(120, 84)
self.act4 = nn.Sigmoid()
self.linear3 = nn.Linear(84, output_channle)
def forward(self, x): # 定义模型的前向操作
out = self.pool1(self.act1(self.conv1(x)))
out = self.pool2(self.act2(self.conv2(out)))
out = self.flatten(out)
out = self.act3(self.linear1(out))
out = self.act4(self.linear2(out))
out = self.linear3(out)
return out
我们通过继承nn.Module
来定义神经网络,并在__init__
中初始化神经网络层。每个nn.Module
子类都在forward()
中实现对输入数据的操作。
我们创建一个LeNet
的一个实例,并打印其结构。
model = LeNet() # intput_channle、output_channle参数值默认
print("Model structure: {model}")
Model structure: LeNet(
(conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(act1): Sigmoid()
(pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(act2): Sigmoid()
(pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear1): Linear(in_features=400, out_features=120, bias=True)
(act3): Sigmoid()
(linear2): Linear(in_features=120, out_features=84, bias=True)
(act4): Sigmoid()
(linear3): Linear(in_features=84, out_features=10, bias=True)
)
linear3): Linear(in_features=84, out_features=10, bias=True)
)