CNN代码部分
1. 导入必要的包
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import DatasetFolder
from tqdm.auto import tqdm
2. 定义CNN网络
class CNN(nn.Module):
    """Three-stage convolutional classifier followed by a fully-connected head.

    The first linear layer expects ``256 * 8 * 8`` features. The pooling
    stages downsample by 2, 2 and 4 (a factor of 16 overall), so inputs must
    be 3-channel images whose spatial size pools down to 8x8 -- e.g. 128x128.

    Args:
        num_classes: Number of output logits. Defaults to 11, the value the
            network previously hard-coded.
    """

    def __init__(self, num_classes=11):
        super().__init__()
        # Each stage: 3x3 conv (stride 1, padding 1 keeps H/W) -> BN -> ReLU -> max-pool.
        self.cnn_layers = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),  # halves H/W (128 -> 64 for a 128x128 input)

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),  # halves H/W (64 -> 32)

            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(4, 4, 0),  # quarters H/W (32 -> 8)
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return unnormalized class logits of shape ``(batch, num_classes)``.

        ``x`` must be ``(batch, 3, H, W)`` with H/W that pool down to 8x8
        (e.g. 128x128); otherwise the first linear layer raises a shape error.
        """
        x = self.cnn_layers(x)
        x = x.flatten(1)  # (batch, 256, 8, 8) -> (batch, 256 * 8 * 8)
        return self.fc_layers(x)
3. 定义数据变换（tfm）和 DataLoader，用于数据的加载与整理
def get_tfm():
    """Build the image transform pipelines.

    Returns:
        A ``(train_tfm, test_tfm)`` pair. Both pipelines resize images to
        128x128 and convert them to tensors. The training pipeline is where
        augmentation can be added; the evaluation pipeline (validation and
        test) intentionally has no augmentation.
    """
    size = (128, 128)
    # Training pipeline -- further augmentation transforms may be inserted here.
    train_steps = [transforms.Resize(size), transforms.ToTensor()]
    # Evaluation pipeline for validation/test data: no augmentation.
    eval_steps = [transforms.Resize(size), transforms.ToTensor()]
    return transforms.Compose(train_steps), transforms.Compose(eval_steps)
def _load_rgb_image(path):
    """Read an image file fully into memory as a 3-channel RGB PIL image."""
    # Image.open is lazy and keeps the file open until pixels are read; the
    # context manager closes it once convert() has loaded the data. Converting
    # to RGB normalizes grayscale/RGBA/palette files to the 3 channels the
    # network's first Conv2d expects. A named function (unlike a lambda) can
    # also be pickled when the DataLoader uses worker processes.
    with Image.open(path) as img:
        return img.convert('RGB')


def prep_dataloader(path, mode, batch_size, tfm):
    """Create a DataLoader over an image folder tree.

    Args:
        path: Root directory; DatasetFolder treats each sub-directory as a class.
        mode: 'train', 'dev' or 'test'. Every mode except 'test' is shuffled,
            so test predictions follow the dataset's file order.
        batch_size: Number of images per batch.
        tfm: Transform applied to each loaded image.

    Returns:
        A DataLoader yielding ``(images, labels)`` batches.
    """
    dataset = DatasetFolder(path, loader=_load_rgb_image, extensions='jpg', transform=tfm)
    return DataLoader(dataset, batch_size, shuffle=(mode != 'test'))
4. 定义config
class config:
    """Hyper-parameters and runtime settings for training.

    All settings are plain instance attributes, so they can be overridden
    after construction (e.g. ``cfg.n_epoch = 5``).
    """

    def __init__(self):
        self.seed = 0             # seed value for random number generators
        self.n_epoch = 80         # number of passes over the training data
        self.batch_size = 128     # images per mini-batch
        self.lr = 0.0003          # learning rate
        self.weight_decay = 1e-5  # weight-decay coefficient (must be passed to the optimizer)
        # Prefer a GPU when one is available; otherwise use the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
5. 定义训练
def _run_epoch(batches, model, criterion, device, optimizer=None):
    """Run one pass over ``batches`` and return ``(mean_loss, accuracy)``.

    If ``optimizer`` is given, the model is put in training mode and updated
    after every batch. Otherwise it is put in eval mode and run without
    gradient tracking. Both metrics are averaged over samples rather than
    batches, so a smaller final batch is not over-weighted.

    Raises:
        ValueError: if ``batches`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    for imgs, labels in batches:
        imgs, labels = imgs.to(device), labels.to(device)
        with torch.set_grad_enabled(training):
            logits = model(imgs)
            loss = criterion(logits, labels)
        if training:
            optimizer.zero_grad()
            loss.backward()
            # Cap the global gradient norm at 10 to guard against exploding gradients.
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
            optimizer.step()
        batch_size = labels.size(0)
        # criterion returns the batch mean; re-weight it by batch size.
        total_loss += loss.item() * batch_size
        total_correct += (logits.argmax(dim=-1) == labels).sum().item()
        total_seen += batch_size
    if total_seen == 0:
        raise ValueError('cannot run an epoch over an empty DataLoader')
    return total_loss / total_seen, total_correct / total_seen


def train(tr_set, dv_set, model, cfg):
    """Train ``model`` with Adam, validating after every epoch.

    For each of ``cfg.n_epoch`` epochs, the model is trained on ``tr_set`` and
    then evaluated on ``dv_set``. Sample-averaged accuracy and loss are printed
    for both. The model is left in eval mode.

    Args:
        tr_set: Iterable of training ``(images, labels)`` batches.
        dv_set: Iterable of validation ``(images, labels)`` batches.
        model: Network already placed on ``cfg.device`` (inputs are moved there).
        cfg: Config providing ``lr``, ``weight_decay``, ``n_epoch`` and ``device``.
    """
    # Apply the configured L2 weight decay alongside the learning rate.
    opt = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    criterion = nn.CrossEntropyLoss()
    for epoch in range(cfg.n_epoch):
        train_loss, train_acc = _run_epoch(tqdm(tr_set), model, criterion, cfg.device, opt)
        print('epoch : %d, train_acc : %f, train_loss : %f ' % (epoch + 1, train_acc, train_loss))
        valid_loss, valid_acc = _run_epoch(tqdm(dv_set), model, criterion, cfg.device)
        print('epoch : %d, valid_acc : %f, valid_loss : %f' % (epoch + 1, valid_acc, valid_loss))
6. 定义测试
def test(tt_set, model, cfg):
    """Predict a class index for every image in ``tt_set``.

    Args:
        tt_set: Iterable of ``(images, labels)`` batches; labels are ignored.
        model: Trained network placed on ``cfg.device``.
        cfg: Config providing ``device``.

    Returns:
        A list of ints (argmax of the logits), in the loader's iteration order.
    """
    # Switch to eval mode so layers such as BatchNorm use their running
    # statistics, whatever mode the caller left the model in.
    model.eval()
    preds = []
    with torch.no_grad():
        for imgs, _labels in tqdm(tt_set):
            logits = model(imgs.to(cfg.device))
            preds.extend(logits.argmax(dim=-1).cpu().tolist())
    return preds
7. 主程序以及运行结果
# Dataset root. Each split is read from its own sub-directory below.
# NOTE(review): the sub-directory names assume the food-11 layout
# (training/labeled, validation, testing) -- confirm against your copy.
path = './data/food-11'
cfg = config()
# Seed PyTorch's RNGs (weight initialization, DataLoader shuffling) from the config.
torch.manual_seed(cfg.seed)
train_tfm, test_tfm = get_tfm()
# Loading every split from the same root would validate and test on the
# training images (and turn the split folders themselves into "classes").
tr_set = prep_dataloader(os.path.join(path, 'training', 'labeled'), 'train', cfg.batch_size, train_tfm)
dv_set = prep_dataloader(os.path.join(path, 'validation'), 'dev', cfg.batch_size, test_tfm)
tt_set = prep_dataloader(os.path.join(path, 'testing'), 'test', cfg.batch_size, test_tfm)
model = CNN().to(cfg.device)
train(tr_set, dv_set, model, cfg)
preds = test(tt_set, model, cfg)