环境
- PyCharm Community Edition 2021.3.1
- Pytorch
代码实现
CNN卷积网络代码
import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.layer1 = nn.Sequential(nn.Conv2d(3,16,3,padding=1), # 第一个卷积层,输入通道数3,输出通道数16,卷积核大小3*3
nn.ReLU(True), # 第一次卷积结果经过ReLU激活函数处理
nn.MaxPool2d(kernel_size=2, stride=2) # 第一次池化,池化大小2*2,方式Max pooling
)
self.layer2 = nn.Sequential(nn.Conv2d(16,16,3,padding=1),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc = nn.Sequential(nn.Linear(56 * 56 * 16, 128),# 第一个全连接层,线性连接,输入节点数56*56*16,输出节点数128
nn.ReLU(True),
nn.Linear(128, 64),# 第二个全连接层,线性连接,输入节点数128,输出节点数64
nn.ReLU(True),
nn.Linear(64, 2)# 第三个全连接