CNN实现对FashionMNIST图像分类
卷积神经网络相对于全连接神经网络的优势:
- 参数少 -> 权值共享
因为全连接神经网络输入的图片像素较大, 所以参数较多
而卷积神经网络的参数主要在于核上, 而且核的参数可以共享给其他通道 - 全连接神经网络会将输入的图片拉直, 这样就会使图片损失原来的效果,从而导致效果不佳
而卷积神经网络不会将图片拉直,用步长去移动核 - 可以手动选取特征,训练好权重,特征分类效果比全连接神经网络的效果好
CNN过程:
conolution层: 实现对feature map局部采样(相似于感受野)
pooling层: 增加感受野
dense层: 也就是全连接层
大概思路
- 加载数据集
- 构建CNN模型
- 训练函数
- 训练模型
- 可视化效果
使用pytorch实现CNN
使用的是Fashimnist数据集, 和以前的线性回归加载数据集的方式一样
日常导入需要用到的python库
import torch
import torch.nn as nn
import torch.optim as optim
import torch.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
加载数据集
与线性回归一样, 就不再阐述
# 加载数据集
train_data = torchvision.datasets.FashionMNIST("/home/kesci/input/FashionMNIST2065",
train=True,
transform=transforms.ToTensor(),
download=False)
test_data = torchvision.datasets.FashionMNIST("/home/kesci/input/FashionMNIST2065",
train=False,
transform=transforms.ToTensor(),
download=False)
# 批量加载数据
train_iter = torch.utils.data.DataLoader(train_data, batch_size=64,
shuffle=True,
num_workers=4)
test_iter = torch.utils.data.DataLoader(train_data, batch_size=64,
shuffle=False,
num_workers=<