利用Pytorch搭建CNN网络
这个实验用到的数据集是Mnist数据集,图片维度是1×28×28
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 卷积层
self.conv1 = nn.Sequential(
nn.Conv2d( # 图片的维度:(1, 28, 28)
in_channels= 1, # 图片的高度
out_channels= 16, # 输出的高度:filter的个数
kernel_size =5, # filter的像素点是5×5
stride = 1, # 每次扫描跳的范围
padding = 2 # 补全边缘像素点
), # 图片的维度:(16,28,28)
nn.ReLU(), # 图片的维度:(16,28,28)
nn.MaxPool2d(kernel_size=2,), # 图片的维度:(16,14,14)
)
# 卷积层
self.conv2 = nn.Sequential( # 图片的维度:(16,14,14)
nn.Conv2d(16,32,5,1,2), # 图片的维度:(32,14,14)
nn.ReLU(),
nn.MaxPool2d(2) # 图片的维度:(32&