官网中文文档 神经网络
核心代码
首先介绍一下 torch.nn.Conv2d(),传入参数的含义如下:
in_channels # 输入通道数
out_channels # 输出通道数
kernel_size # 卷积核尺寸,常见有 1,3,5,7
stride # 步长,默认为1
padding # 填充,默认零填充
dilation # 空洞卷积,默认为 1
groups # 组卷积,默认为 1
bias # 是否需要偏置,默认为 True
和原代码在形式上稍微有点不同,这里使用了 nn.Sequential() 模块快速进行搭建。上一层的输出直接作为下一层的输入。输入要求为 1 * 1 * 32 * 32 的四维张量。
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.features = nn.Sequential(
# 输入通道数为 1,输出通道数为 6,有 6 个 1*5*5 卷积核
nn.Conv2d(1, 6, 5),
nn.MaxPool2d(2, 2),
# 输入通道数为 6,输出通道数为 16,有 16 个 6*5*5 卷积核
nn.Conv2d(6, 16, 5),
nn.MaxPool2d(2, 2),
)
self.classifier = nn.Sequential(
nn.Linear(16*5*5, 120),
nn.ReLU(True),
nn.Linear(120, 84),
nn.ReLU(True),
nn.Linear(84, 10),
)
def forward(self, x):
# 卷积
x = self.features(x)
# x 为 4 维张量。把 x 的尺寸调整为 [1, 16*5*5]
x = x.view(x.size(0), -1)
# 分类
x = self.classifier(x)
# x 的尺寸为 [1, 10]
return x
卷积
以 1 * 1 * 32 * 32 的输入为例。第一个 1 表示 batchsize,第二个 1 表示通道数(channel),后面三个参数可视为一个立方体。conv 表示卷积,pooling 表示池化。
![](https://img-blog.csdnimg.cn/20210116174230617.png)
卷积 + 分类
![](https://img-blog.csdnimg.cn/20210112200530492.png)
网络架构
print(Net())
输出
Net(
(features): Sequential(
(0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=400, out_features=120, bias=True)
(1): ReLU(inplace=True)
(2): Linear(in_features=120, out_features=84, bias=True)
(3): ReLU(inplace=True)
(4): Linear(in_features=84, out_features=10, bias=True)
)
)
查看参数
params = list(net.parameters())
len(params)
# conv1's .weight
params[0].size()
# conv2's .weight
params[2].size()
输出
10
torch.Size([6, 1, 5, 5]) # 表示 6 个 1 * 5 * 5 的卷积核
torch.Size([16, 6, 5, 5]) # 表示 16 个 6 * 5 * 5 的卷积核
这 10 个参数分别是
conv1.weight
conv1.bias
conv2.weight
conv2.bias
fc1.weight
fc1.bias
fc2.weight
fc2.bias
fc3.weight
fc3.bias
如果想要详细查看个参数的具体数值,这样
# 查看某个参数数值
params[0]
# 或查看所有参数数值
for param in net.parameters():print(param)