# Pytorch 0.4.0 VGG16实现cifar10分类.
# @Time: 2018/6/23
# @Author: xfLi
import torch
import torch.nn as nn
import math
import torchvision.transforms as transforms
import torchvision as tv
from torch.utils.data import DataLoader
model_path = './model_pth/vgg16_bn-6c64b313.pth'
BATCH_SIZE = 1
LR = 0.01
EPOCH = 1
class VGG(nn.Module):
def __init__(self, features, num_classes=10): #构造函数
super(VGG, self).__init__()
# 网络结构(仅包含卷积层和池化层,不包含分类器)
self.features = features
self.classifier = nn.Sequential( #分类器结构
#fc6
nn.Linear(512*7*7, 4096),
nn.ReLU(),
nn.Dropout(),
#fc7
nn.Linear(4096, 4096),
nn.ReLU(),
nn.Dropout(),
#fc8
nn.Linear(4096, num_classes))
# 初始化权重
self._initialize_weights()
def forward(self, x):
x
【PyTorch】VGG16分类
最新推荐文章于 2024-08-06 17:21:01 发布
本文档展示了如何使用PyTorch 0.4.0构建VGG16网络模型,并应用于CIFAR10数据集的图像分类任务。通过定义网络结构、数据预处理、训练和测试过程,详细解释了模型的搭建和训练流程。
摘要由CSDN通过智能技术生成