Training VGG16 on CIFAR-10

import os
import torch
from torch.utils.data import Dataset
from torchvision.io import read_image
from torch.utils.data import DataLoader
import sys
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from typing import List, cast

classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
class Cinic10(Dataset):
    """CINIC-10 loader: expects one subfolder per class; kept for reference, not used in the run below."""
    def __init__(self, img_dir):
        self.img_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        self.label2id = {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
                         'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
        self.id2label = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
                         5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
        self.img_dir = img_dir
        self.cinic_mean = [0.47889522, 0.47227842, 0.43047404]
        self.cinic_std = [0.24205776, 0.23828046, 0.25874835]
        self.transform = transforms.Compose([transforms.Normalize(mean=self.cinic_mean, std=self.cinic_std)])
        self.countPerLabel = len(
            [name for name in os.listdir(os.path.join(self.img_dir, self.img_labels[0]))
             if os.path.isfile(os.path.join(os.path.join(self.img_dir, self.img_labels[0]), name))])
        self.len = len(self.img_labels) * self.countPerLabel
        self.X_Y = []
        for label in self.img_labels:
            img_path = os.path.join(self.img_dir, label)
            images_files = [name for name in os.listdir(img_path) if os.path.isfile(os.path.join(img_path, name))]
            label_id = self.label2id[label]
            for images_file in images_files:
                image = read_image(os.path.join(img_path, images_file))
                if image.shape != torch.Size([3, 32, 32]):
                    # assume a grayscale image (1 channel) and repeat the channel to get 3
                    image = torch.cat([image, image, image])
                # read_image returns uint8 in [0, 255]; scale to [0, 1] so the
                # normalization statistics (given in [0, 1]) apply correctly
                image = image.type(torch.float32) / 255.0
                image = self.transform(image)
                self.X_Y.append([image, label_id])

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        [image, label] = self.X_Y[idx]
        return image, label


class Cifar10(Dataset):
    """CIFAR-10 loader: expects one subfolder per class, e.g. CIFAR10/train/airplane/."""
    def __init__(self, img_dir):
        self.img_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        self.label2id = {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
                         'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}
        self.id2label = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
                         5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}
        self.img_dir = img_dir
        # note: these are the standard ImageNet channel statistics, not CIFAR-10's own
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.transform = transforms.Compose([transforms.Normalize(mean=self.mean, std=self.std)])
        self.countPerLabel = len(
            [name for name in os.listdir(os.path.join(self.img_dir, self.img_labels[0]))
             if os.path.isfile(os.path.join(os.path.join(self.img_dir, self.img_labels[0]), name))])
        self.len = len(self.img_labels) * self.countPerLabel
        self.X_Y = []
        for label in self.img_labels:
            img_path = os.path.join(self.img_dir, label)
            images_files = [name for name in os.listdir(img_path) if os.path.isfile(os.path.join(img_path, name))]
            label_id = self.label2id[label]
            for images_file in images_files:
                image = read_image(os.path.join(img_path, images_file))
                if image.shape != torch.Size([3, 32, 32]):
                    # assume a grayscale image (1 channel) and repeat the channel to get 3
                    image = torch.cat([image, image, image])
                # read_image returns uint8 in [0, 255]; scale to [0, 1] so the
                # normalization statistics (given in [0, 1]) apply correctly
                image = image.type(torch.float32) / 255.0
                image = self.transform(image)
                self.X_Y.append([image, label_id])

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        [image, label] = self.X_Y[idx]
        return image, label


traindataset = Cifar10('CIFAR10/train')
traindataloader = DataLoader(traindataset, batch_size=256, shuffle=True)
testdataset = Cifar10('CIFAR10/test')
testdataloader = DataLoader(testdataset, batch_size=256)
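As an aside, since the images are stored one subfolder per class, essentially the same loaders could also be built with torchvision's ImageFolder (the class names here are alphabetical, so ImageFolder's automatic label indices match label2id). A minimal sketch, assuming the same CIFAR10/train layout:

# alternative loading sketch via ImageFolder (assumes the same folder-per-class layout)
from torchvision import datasets

alt_transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
alt_trainset = datasets.ImageFolder('CIFAR10/train', transform=alt_transform)
alt_trainloader = DataLoader(alt_trainset, batch_size=256, shuffle=True)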


def make_layers(cfg):
    """Builds the VGG convolutional feature extractor from a layer config list."""
    layers: List[nn.Module] = []
    in_channels = 3
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            v = cast(int, v)
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


class VGG(nn.Module):
    def __init__(self, features, num_classes, dropout):
        super().__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


VGG16_layers = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"]
device = torch.device('cuda:1')  # second GPU; change to 'cuda:0' or 'cpu' to match your machine
net = VGG(make_layers(VGG16_layers), len(classes), 0.5)
net = net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
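Before training, a quick sanity check (not part of the original run) can confirm that a CIFAR-sized batch flows through the network and produces 10 logits per image:

# shape check with a dummy batch of two 32x32 RGB images
net.eval()
with torch.no_grad():
    dummy = torch.randn(2, 3, 32, 32, device=device)
    print(net(dummy).shape)  # expected: torch.Size([2, 10])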


def adjust_learning_rate(optimizer, epoch):
    """Halves the learning rate every 30 epochs, starting from the initial LR of 0.05."""
    lr = 0.05 * (0.5 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
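The same schedule could equally be expressed with the built-in StepLR scheduler; a sketch of the equivalent setup (one would then call scheduler.step() once per epoch instead of adjust_learning_rate):

# equivalent schedule: halve the learning rate every 30 epochs
from torch.optim.lr_scheduler import StepLR
scheduler = StepLR(optimizer, step_size=30, gamma=0.5)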


def train(dataloader, net, criterion, optimizer, device):
    net.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for data in tqdm(dataloader, desc='training...', file=sys.stdout):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # loss.item() is the batch-mean loss, so weight it by the batch size
        running_loss += loss.item() * labels.size(0)
    return running_loss/total, correct/total

def evaluate(dataloader, net, criterion, device):
    net.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in tqdm(dataloader, desc='evaluating...', file=sys.stdout):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            running_loss += loss.item() * labels.size(0)  # weight batch-mean loss by batch size
    return running_loss/total, correct/total
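Beyond the overall accuracy, the same evaluation machinery can be extended to report per-class accuracy; a small sketch along the same lines (per_class_accuracy is a helper introduced here, not part of the original script):

def per_class_accuracy(dataloader, net, device, num_classes=10):
    # count correct predictions and totals separately for each class id
    net.eval()
    correct = torch.zeros(num_classes)
    total = torch.zeros(num_classes)
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = net(inputs).max(1)[1]
            for c in range(num_classes):
                mask = labels == c
                total[c] += mask.sum().item()
                correct[c] += (predicted[mask] == c).sum().item()
    return (correct / total.clamp(min=1)).tolist()

Calling per_class_accuracy(testdataloader, net, device) returns a list indexed by label id, so classes[i] gives the corresponding class name.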


n_epochs = 300
best_valid_acc = 0
os.makedirs('CV/CIFAR10', exist_ok=True)  # make sure the checkpoint directory exists
for epoch in range(n_epochs):
    adjust_learning_rate(optimizer, epoch)
    train_loss, train_acc = train(traindataloader, net, criterion, optimizer, device)
    valid_loss, valid_acc = evaluate(testdataloader, net, criterion, device)
    print(f'epoch: {epoch + 1}')
    print(f'train_loss: {train_loss:.3f}, train_acc: {train_acc:.3f}')
    print(f'valid_loss: {valid_loss:.3f}, valid_acc: {valid_acc:.3f}')
    if valid_acc > best_valid_acc:
        print(f'valid acc improved from {best_valid_acc:.3f} to {valid_acc:.3f}, saving checkpoint')
        best_valid_acc = valid_acc
        torch.save(net.state_dict(), 'CV/CIFAR10/VGG16.pth')
    else:
        print(f'best valid acc is {best_valid_acc:.3f}')

net2 = VGG(make_layers(VGG16_layers), len(classes), 0.5)
net2 = net2.to(device)
net2.load_state_dict(torch.load('CV/CIFAR10/VGG16.pth', map_location=device))
valid_loss, valid_acc = evaluate(testdataloader, net2, criterion, device)
print(f'best model valid acc: {valid_acc:.3f}')
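Finally, a short sketch of how the reloaded checkpoint could be used to classify a single image (the file path below is only a placeholder; point it at any image from the test folder):

# single-image inference sketch; 'CIFAR10/test/cat/0001.png' is a hypothetical path
img = read_image('CIFAR10/test/cat/0001.png').type(torch.float32) / 255.0
img = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
net2.eval()
with torch.no_grad():
    pred = net2(img.unsqueeze(0).to(device)).max(1)[1].item()
print(f'predicted class: {classes[pred]}')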
