from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import torch
import torchvision
import torch.nn as nn
import numpy as np
from torch.utils import data
from torch import optim
from logger import Logger
# Select GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters.
num_epochs = 80
learning_rate = 0.001

# Image preprocessing: pad by 4, random flip, random crop back to 32x32,
# then convert to a tensor (training-time augmentation).
transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32),
    transforms.ToTensor(),
])

# CIFAR-10 datasets: augmented training split, tensor-only test split.
# NOTE(review): download=False assumes ./data/ is already populated — confirm.
train_dataset = torchvision.datasets.CIFAR10(root='./data/',
                                             train=True,
                                             transform=transform,
                                             download=False)
test_dataset = torchvision.datasets.CIFAR10(root='./data/',
                                            train=False,
                                            transform=transforms.ToTensor())

# Mini-batch loaders: shuffle the training data, keep the test order fixed.
train_loader = data.DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=100, shuffle=False)
def Conv3(in_channels, out_channels, stride=1):
    """Build the 3x3 convolution (padding 1, no bias) used throughout the net."""
    conv = nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=1, bias=False)
    return conv
# Residual block
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block.

    Computes relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x)).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution (downsamples spatially when > 1).
        downsample: optional module that projects the input to the main
            branch's output shape when stride or channel count differs;
            when None the identity shortcut is used.
    """
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.conv1 = Conv3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = Conv3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Explicit None check: truthiness of an nn.Module is unreliable —
        # container modules define __len__, so an empty nn.Sequential is falsy
        # and `if self.downsample:` would silently skip it.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
# ResNet for 32x32 inputs (CIFAR-10 style): a 16-channel stem convolution
# followed by three groups of residual blocks at 16/32/64 channels.
class ResNet(nn.Module):
    """Small ResNet classifier.

    Args:
        block: residual-block class to instantiate (e.g. ResidualBlock).
        layers: layers[i] is the number of blocks in the i-th group.
        num_class: number of output classes.
    """
    def __init__(self, block, layers, num_class=10):
        super(ResNet, self).__init__()
        self.in_channels = 16
        self.conv = Conv3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)  # halves spatial size
        self.layer3 = self.make_layer(block, 64, layers[2], 2)  # halves again
        # Generalization of the original nn.AvgPool2d(8): identical output for
        # 32x32 inputs (the feature map is 8x8 here), but also accepts other
        # input resolutions. Pooling has no parameters, so state_dicts match.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_class)

    def make_layer(self, block, out_channel, blocks, stride=1):
        """Stack `blocks` residual blocks producing `out_channel` channels.

        Only the first block can change stride/width, so only it receives a
        projection `downsample` (1-conv + BN) to reshape the shortcut; the
        remaining blocks keep the same shape and use identity shortcuts.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channel:
            downsample = nn.Sequential(
                Conv3(self.in_channels, out_channel, stride),
                nn.BatchNorm2d(out_channel),
            )
        layers = [block(self.in_channels, out_channel, stride, downsample)]
        self.in_channels = out_channel
        for _ in range(1, blocks):
            layers.append(block(out_channel, out_channel))
        # The * unpacks the list into Sequential's positional arguments.
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        # Flatten [batch, channel, h, w] -> [batch, channel] for the classifier.
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out
# Build the ResNet-18-style network (2 blocks per group) on the chosen device.
model = ResNet(ResidualBlock, [2, 2, 2]).to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# updating learning rate
def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group in `optimizer` to `lr`."""
    for group in optimizer.param_groups:
        group['lr'] = lr
# Train the model.
total_step = len(train_loader)
curr_lr = learning_rate
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass.
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update. optimizer.zero_grad() is the
        # standard idiom (equivalent to the original model.zero_grad() here,
        # since the optimizer holds exactly the model's parameters).
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            # Bug fix: the original passed `optimizer.param_groups` (a list of
            # dicts) to a '{:.4f}' slot, which raises TypeError on the first
            # log line. Report the scalar lr and use the previously-unused
            # total_step for a step counter instead.
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, LR: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, total_step,
                          loss.item(), optimizer.param_groups[0]['lr']))

    # Decay the learning rate by a factor of 3 every 20 epochs.
    if (epoch+1) % 20 == 0:
        curr_lr /= 3
        update_lr(optimizer, curr_lr)
# Evaluate on the test set.
model.eval()  # switch BatchNorm layers to their running statistics
with torch.no_grad():  # inference only: skip autograd bookkeeping
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # `.data` removed: it bypasses autograd version tracking and is
        # redundant (and discouraged) inside torch.no_grad().
        _, predict = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predict == labels).sum().item()
    print('acc {}%'.format(100*correct/total))

# Persist the trained weights.
torch.save(model.state_dict(), 'resnet.ckpt')