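1. Introduction
1.1 Problem Description
This post implements AlexNet in PyTorch for simple image classification.
AlexNet is trained on a flower classification dataset. Given an RGB image as input, the model outputs the category of the flower in the image.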
Flower classification dataset download link:
https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
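The post only gives the download link. The functions below sketch one way to fetch the archive and lay it out under `data/`. They download and extract `flower_photos.tgz`, then copy each class folder into `train/` and `val/` subfolders. Several details are assumptions for illustration, not necessarily what the original `train.py` expects:

- the function names `download_and_extract` and `split_dataset`,
- the `data/train` + `data/val` layout,
- the 10% validation ratio.

```python
import random
import shutil
import tarfile
import urllib.request
from pathlib import Path

DATA_URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"


def download_and_extract(root: Path) -> Path:
    """Download the archive into `root` (if missing), extract it, return the extracted folder."""
    root.mkdir(parents=True, exist_ok=True)
    archive = root / "flower_photos.tgz"
    if not archive.exists():
        urllib.request.urlretrieve(DATA_URL, archive)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(root)
    return root / "flower_photos"


def split_dataset(src: Path, dst: Path, val_ratio: float = 0.1, seed: int = 0) -> dict:
    """Hypothetical helper: copy every class folder of `src` into dst/train/<class> and dst/val/<class>.

    Returns {class_name: (n_train, n_val)}. Plain files such as LICENSE.txt are skipped.
    """
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    counts = {}
    for class_dir in sorted(p for p in src.iterdir() if p.is_dir()):
        images = sorted(class_dir.glob("*.jpg"))
        rng.shuffle(images)
        n_val = int(len(images) * val_ratio)
        for i, img in enumerate(images):
            split = "val" if i < n_val else "train"
            out_dir = dst / split / class_dir.name
            out_dir.mkdir(parents=True, exist_ok=True)
            shutil.copy2(img, out_dir / img.name)
        counts[class_dir.name] = (len(images) - n_val, n_val)
    return counts
```

For example, `split_dataset(download_and_extract(Path("data")), Path("data"))` would place the five class folders (daisy, dandelion, roses, sunflowers, tulips) under both `data/train` and `data/val`.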
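1.2 Model Structure
From input to output, the network consists of:
convolutional layer 1, max-pooling layer 1, convolutional layer 2, max-pooling layer 2, convolutional layers 3, 4 and 5, max-pooling layer 3, and fully connected layers 1, 2 and 3. Every convolutional layer and the first two fully connected layers are followed by a ReLU, and dropout (p = 0.5) is applied before fully connected layers 1 and 2. The channel counts used here (48, 128, 192, 192, 128) are half of those in the original AlexNet paper, and the hidden fully connected layers have 2048 units instead of 4096.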
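The spatial sizes in the code comments below follow from the standard output-size formula for `Conv2d` and `MaxPool2d` (dilation 1, floor rounding): out = floor((W + 2P - K) / S) + 1. The short script below applies the formula layer by layer to a 224×224 input; the helper names `out_size` and `trace` are hypothetical. The final size of 6 explains why fully connected layer 1 takes 128 * 6 * 6 = 4608 inputs.

```python
def out_size(w: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    # Spatial output size of a Conv2d / MaxPool2d layer (dilation 1, floor mode)
    return (w + 2 * padding - kernel) // stride + 1


# (name, kernel, stride, padding) for each spatial layer, in network order
LAYERS = [
    ("conv1", 11, 4, 2), ("pool1", 3, 2, 0),
    ("conv2", 5, 1, 2), ("pool2", 3, 2, 0),
    ("conv3", 3, 1, 1), ("conv4", 3, 1, 1), ("conv5", 3, 1, 1),
    ("pool3", 3, 2, 0),
]


def trace(w: int = 224) -> list:
    """Return [(layer_name, output_width), ...] for a square w x w input."""
    sizes = []
    for name, k, s, p in LAYERS:
        w = out_size(w, k, s, p)
        sizes.append((name, w))
    return sizes


for name, w in trace():
    print(f"{name}: {w}x{w}")
```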
2. Code Implementation
2.1 Project Structure
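- data: dataset directory
- example: images used to test the model
- models: classification model definitions
  - AlexNet.py: builds the AlexNet model
- weights: directory where the trained model weights are saved
- class_indices.json: dataset class indices
- predict.py: tests the model's classification on sample images
- train.py: trains the model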
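The contents of `class_indices.json` are not shown. The sketch below shows one plausible way to produce the file, under two assumptions:

- training uses torchvision's `ImageFolder`, which numbers classes by sorted folder name,
- the file maps index strings to class names.

The helper names `write_class_indices` and `load_class_indices` are hypothetical.

```python
import json
from pathlib import Path


def write_class_indices(train_dir: Path, out_file: Path) -> dict:
    """Write {"0": "daisy", "1": "dandelion", ...} for the class folders in train_dir."""
    # Sorted folder names: the same order ImageFolder uses to build class_to_idx
    classes = sorted(p.name for p in train_dir.iterdir() if p.is_dir())
    # Index -> name mapping; JSON object keys must be strings
    indices = {str(i): name for i, name in enumerate(classes)}
    out_file.write_text(json.dumps(indices, indent=4))
    return indices


def load_class_indices(path: Path) -> dict:
    """Read the file back with integer keys, e.g. for looking up a predicted index."""
    return {int(k): v for k, v in json.loads(path.read_text()).items()}
```

With the layout from the dataset sketch, this would be called as `write_class_indices(Path("data/train"), Path("class_indices.json"))`.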
2.2 Environment
- Python 3.10
- PyTorch 2.0.0
- Ubuntu 22.04
2.3 Model Implementation
AlexNet.py
import torch
import torch.nn as nn


class AlexNet(nn.Module):
    def __init__(self, num_classes, init_weights=True):
        super(AlexNet, self).__init__()
        # Feature extractor: 5 convolutional layers and 3 max-pooling layers
        self.feature_extract = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, padding=2, stride=4),  # input: [3, 224, 224] output: [48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # output: [48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2, stride=1),  # output: [128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # output: [128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1, stride=1),  # output: [192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1, stride=1),  # output: [192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1, stride=1),  # output: [128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),  # output: [128, 6, 6]
        )
        # Classifier: 3 fully connected layers, with dropout before the first two
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.feature_extract(x)
        x = torch.flatten(x, start_dim=1)  # [N, 128, 6, 6] -> [N, 4608]
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Kaiming (He) normal init, suited to layers followed by ReLU
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                # Small Gaussian weights N(0, 0.01^2) and zero bias
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
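`_initialize_weights` initializes the two layer types differently:

- Convolutions use Kaiming (He) normal initialization with `mode='fan_out'`. This draws weights with standard deviation sqrt(2 / fan_out), where fan_out = out_channels × kernel_h × kernel_w.
- Linear layers use N(0, 0.01²) weights and zero bias.

The snippet below is a quick sanity check of those numbers. It uses standalone layers shaped like conv2 and fully connected layer 1 rather than importing `models/AlexNet.py`.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# A layer shaped like conv2: 48 -> 128 channels, 5x5 kernel
conv = nn.Conv2d(48, 128, kernel_size=5, padding=2)
nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
fan_out = 128 * 5 * 5  # out_channels * kernel_h * kernel_w
expected_conv_std = math.sqrt(2.0 / fan_out)  # ReLU gain sqrt(2) divided by sqrt(fan_out)

# A layer shaped like fully connected layer 1: 4608 -> 2048
fc = nn.Linear(128 * 6 * 6, 2048)
nn.init.normal_(fc.weight, 0, 0.01)
nn.init.constant_(fc.bias, 0)

conv_std = conv.weight.std().item()
fc_std = fc.weight.std().item()
print(f"conv2 weight std: {conv_std:.5f} (expected {expected_conv_std:.5f})")
print(f"fc1 weight std:   {fc_std:.5f} (expected 0.01)")
```

Using `fan_out` preserves the variance of gradients in the backward pass; `fan_in` (the PyTorch default mode) would preserve it in the forward pass instead.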
2.4 Results
Test image:
Test result:
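The function below is a hedged sketch of the final step that `predict.py` likely performs. It turns the model's raw logits into a class name and probability via softmax, using an index-to-name mapping in the assumed `class_indices.json` format. The function name `decode_prediction` and the logit values are made up for illustration; they are not output from the trained model.

```python
import torch


def decode_prediction(logits: torch.Tensor, class_indices: dict) -> tuple:
    """Map a [1, num_classes] logit tensor to (class_name, probability)."""
    probs = torch.softmax(logits.squeeze(0), dim=0)  # logits -> probabilities summing to 1
    idx = int(torch.argmax(probs).item())            # most likely class index
    return class_indices[str(idx)], probs[idx].item()


class_indices = {"0": "daisy", "1": "dandelion", "2": "roses", "3": "sunflowers", "4": "tulips"}
logits = torch.tensor([[0.1, 2.0, 0.3, -1.0, 0.5]])  # illustrative values only
name, prob = decode_prediction(logits, class_indices)
print(f"class: {name}   prob: {prob:.3f}")
```

In a real prediction script, the logits would come from `model(img)` inside `torch.no_grad()`, after `model.eval()` has disabled dropout.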