(PyTorch) VGGNet 代码复现：CIFAR-10 数据集

本文介绍了如何在 PyTorch 中使用 VGG-16 和 VGG-19 架构实现卷积神经网络，并在 CIFAR-10 数据集上进行图像分类，同时展示了训练过程、参数调整和模型验证。文中用到了数据预处理、损失函数、优化器和学习率调度等关键步骤。
摘要由CSDN通过智能技术生成

model.py

# Define VGG-16 and VGG-19.
import torch

# Layer configurations. An integer is the output-channel count of a
# 3x3 Conv2d (padding 1) + BatchNorm2d + ReLU stage; 'M' is a 2x2 max-pool
# with stride 2. Both configurations contain five 'M' entries and end with
# 512 channels.
cfg = {
    'VGG-16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG-19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
}


# VGG-16 and VGG-19
class VGGNet(torch.nn.Module):
    """VGG-style image classifier with batch normalisation.

    The convolutional trunk is built from ``cfg[VGG_type]``. Its 512-channel
    output is average-pooled to 1x1 and fed to a single linear classifier.

    Args:
        VGG_type: key into ``cfg`` ('VGG-16' or 'VGG-19').
        num_classes: number of output logits.

    Raises:
        KeyError: if ``VGG_type`` is not a key of ``cfg``.
    """

    def __init__(self, VGG_type, num_classes):
        super(VGGNet, self).__init__()
        if VGG_type not in cfg:
            raise KeyError(f"Unknown VGG type {VGG_type!r}; expected one of {sorted(cfg)}")
        self.features = self._make_layers(cfg[VGG_type])
        # Every configuration ends with 512 channels pooled to 1x1,
        # so the classifier always sees a 512-dim vector.
        self.classifier = torch.nn.Linear(512, num_classes)

    def forward(self, x):
        """Return class logits of shape (N, num_classes).

        ``x`` is expected as (N, 3, H, W) because the first conv has 3 input
        channels. H and W should be at least 32, since the trunk halves them
        five times.
        """
        out = self.features(x)
        out = torch.flatten(out, 1)  # (N, 512, 1, 1) -> (N, 512)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        """Build the convolutional trunk described by the list ``cfg``."""
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':  # MaxPool2d
                layers.append(torch.nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [torch.nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           torch.nn.BatchNorm2d(x),
                           torch.nn.ReLU(inplace=True)]
                in_channels = x
        # Pool whatever spatial size remains down to 1x1.
        # For 32x32 CIFAR images the trunk already yields 1x1 maps, so this
        # is an identity there. Larger inputs no longer break the 512-wide
        # classifier. The module has no parameters and keeps the same
        # Sequential index as before, so existing state_dicts still load.
        layers.append(torch.nn.AdaptiveAvgPool2d((1, 1)))
        return torch.nn.Sequential(*layers)  # unpack the list into positional modules

train.py

# import packages
import os
import sys

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from VGGnet.model import VGGNet

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
epochs = 300          # total number of training epochs
batch_size = 100      # mini-batch size shared by both data loaders
learning_rate = 0.01  # initial learning rate handed to the optimizer
num_classes = 10      # CIFAR-10 has ten target classes

# ---------------------------------------------------------------------------
# Transforms and data augmentation
# ---------------------------------------------------------------------------
# Both splits use the same per-channel normalisation (mean 0.5, std 0.5),
# which maps ToTensor's [0, 1] range onto [-1, 1].
_normalize = torchvision.transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

# Training augmentation: pad 4 px on every side, randomly flip horizontally,
# then cut a random 32x32 crop back out of the padded image.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.Pad(4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomCrop(32),
    torchvision.transforms.ToTensor(),
    _normalize,
])

# Evaluation: no augmentation, just tensor conversion and normalisation.
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    _normalize,
])

# ---------------------------------------------------------------------------
# Datasets and loaders
# ---------------------------------------------------------------------------
# Both splits live under ./data; download=True lets torchvision fetch them.
_dataset_kwargs = dict(root='data', download=True)
train_dataset = torchvision.datasets.CIFAR10(train=True, transform=transform_train, **_dataset_kwargs)
val_dataset = torchvision.datasets.CIFAR10(train=False, transform=transform_test, **_dataset_kwargs)

# Only the training split is shuffled; validation order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# Make model.
net_name = 'VGG-16'
# net_name = 'VGG-19'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGGNet(net_name, num_classes).to(device)
# Loss and optimizer.
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

train_num = len(train_dataset)
val_num = len(val_dataset)
train_steps = len(train_loader)
val_steps = len(val_loader)
# scheduler.step() is called once per epoch below. The LR is therefore
# multiplied by 0.1 once 136, and again once 185, epochs have completed.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[136, 185], gamma=0.1)

path_checkpoint = "VGGnet.pth"
resume = True  # whether to continue from the last saved checkpoint
# 1-based number of the first epoch to run. Defined unconditionally so the
# loop below also works when resume is False.
initepoch = 1
if resume:
    if os.path.isfile(path_checkpoint):
        print("Resume from checkpoint...")
        # map_location lets a checkpoint written on a GPU be resumed on a CPU-only machine.
        checkpoint = torch.load(path_checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # Checkpoints written by older versions of this script lack the
            # scheduler. The optimizer state already carries the current LR,
            # so only the scheduler's epoch counter needs fast-forwarding for
            # the remaining milestones to fire at the right epoch.
            scheduler.last_epoch = checkpoint['epoch'] + 1
        # 'epoch' is the 0-based index of the last finished epoch.
        initepoch = checkpoint['epoch'] + 2
        print("====>loaded checkpoint (epoch{})".format(checkpoint['epoch'] + 1))
    else:
        print("====>no checkpoint found.")

writer = SummaryWriter("logs")
try:
    for epoch in range(initepoch - 1, epochs):
        # ---- train ----
        print("-------第 {} 轮训练开始-------".format(epoch + 1))
        model.train()
        train_acc = 0.0
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for images, labels in train_bar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            running_loss += batch_loss

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, batch_loss)
            _, predict = torch.max(outputs, dim=1)
            train_acc += torch.eq(predict, labels).sum().item()
        train_loss = running_loss / train_steps
        train_accurate = train_acc / train_num

        # ---- validate ----
        model.eval()
        val_acc = 0.0
        running_loss = 0.0
        with torch.no_grad():
            val_bar = tqdm(val_loader, file=sys.stdout)
            for val_images, val_labels in val_bar:
                val_images, val_labels = val_images.to(device), val_labels.to(device)
                outputs = model(val_images)
                loss = loss_function(outputs, val_labels)
                running_loss += loss.item()
                _, predict = torch.max(outputs, dim=1)
                val_acc += torch.eq(predict, val_labels).sum().item()
        val_loss = running_loss / val_steps
        val_accurate = val_acc / val_num
        scheduler.step()
        print('[epoch %d] train_loss: %.3f val_loss:%.3f train_accuracy:%.3f val_accuracy: %.3f' %
              (epoch + 1, train_loss, val_loss, train_accurate, val_accurate))
        writer.add_scalars('loss',
                           {'train': train_loss, 'val': val_loss}, global_step=epoch)
        writer.add_scalars('acc',
                           {'train': train_accurate, 'val': val_accurate}, global_step=epoch)

        # Save a checkpoint. The scheduler state is included so a resumed run
        # keeps the LR milestones aligned with the true epoch count.
        checkpoint = {"model_state_dict": model.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "scheduler_state_dict": scheduler.state_dict(),
                      "epoch": epoch}
        # Write to a temporary file and swap it in. An interruption during
        # torch.save then cannot leave a truncated checkpoint behind.
        tmp_checkpoint = path_checkpoint + ".tmp"
        torch.save(checkpoint, tmp_checkpoint)
        os.replace(tmp_checkpoint, path_checkpoint)
        print("保存模型成功")
    print('Finished Training')
finally:
    # Flush TensorBoard logs even if training is interrupted (e.g. Ctrl+C).
    writer.close()

程序支持断点续训：中断后再次运行即可从上次保存的检查点继续训练。训练日志写入 `logs` 目录，可以用 TensorBoard 查看（例如 `tensorboard --logdir logs`）。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值