# -*- coding: utf-8 -*-
'''
@Time : 2020/5/26 20:59
@Author : HHNa
@FileName: 3.model_train_val.py
@Software: PyCharm
'''
import torch
import numpy as np
import torch.nn as nn
import glob, json
from dataset import SVHNDataset
from model import SVHN_Model1
import torchvision.transforms as transforms
from tensorboard_logger import Logger
def train(train_loader, model, criterion, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Each target is a (batch, 5) tensor of digit-class indices; the model
    emits five classification heads, one per digit position.

    Args:
        train_loader: iterable of (images, targets) batches.
        model: network returning a 5-tuple of per-position logits.
        criterion: classification loss applied to each head.
        optimizer: optimizer stepping the model's parameters.

    Returns:
        Mean of the per-batch losses over the epoch.
    """
    # Switch the model to training mode (enables dropout/batchnorm updates).
    model.train()
    train_loss = []
    # NOTE: renamed the loop variable from `input` — it shadowed the builtin.
    for i, (data, target) in enumerate(train_loader):
        c0, c1, c2, c3, c4 = model(data)
        target = target.long()
        # Average the loss over the five digit-position heads.
        loss = (criterion(c0, target[:, 0]) +
                criterion(c1, target[:, 1]) +
                criterion(c2, target[:, 2]) +
                criterion(c3, target[:, 3]) +
                criterion(c4, target[:, 4])) / 5
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
        if i % 100 == 0:
            print(loss.item())
    return np.mean(train_loss)
def validate(val_loader, model, criterion):
    """Evaluate the model on the validation set; return the mean per-batch loss.

    Args:
        val_loader: iterable of (images, targets) batches.
        model: network returning a 5-tuple of per-position logits.
        criterion: classification loss applied to each head.

    Returns:
        Mean of the per-batch losses over the validation set.
    """
    # Switch the model to evaluation mode (freezes dropout/batchnorm).
    model.eval()
    val_loss = []
    # Do not track gradients during evaluation.
    with torch.no_grad():
        # NOTE: renamed the loop variable from `input` — it shadowed the builtin.
        for data, target in val_loader:
            c0, c1, c2, c3, c4 = model(data)
            target = target.long()
            # Average the loss over the five digit-position heads.
            loss = (criterion(c0, target[:, 0]) +
                    criterion(c1, target[:, 1]) +
                    criterion(c2, target[:, 2]) +
                    criterion(c3, target[:, 3]) +
                    criterion(c4, target[:, 4])) / 5
            val_loss.append(loss.item())
    return np.mean(val_loss)
if __name__ == "__main__":
train_path = glob.glob('./data/train/mchar_train/*.png')
train_path.sort()
train_json = json.load(open('./data/train/mchar_train.json'))
train_label = [train_json[x]['label'] for x in train_json]
val_path = glob.glob('./data/val/mchar_val/*.png')
val_path.sort()
val_json = json.load(open('./data/val/mchar_val.json'))
val_label = [train_json[x]['label'] for x in val_json]
train_dataset = SVHNDataset(train_path, train_label,
transforms.Compose([
transforms.Resize((64, 128)),
transforms.ColorJitter(0.3, 0.3, 0.2),
transforms.RandomRotation(5),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]))
val_dataset = SVHNDataset(val_json, val_label,
transforms.Compose([
transforms.Resize((64, 128)),
transforms.ColorJitter(0.3, 0.3, 0.2),
transforms.RandomRotation(5),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]))
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=10,
shuffle=True,
num_workers=10,
)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=10,
shuffle=False,
num_workers=10,
)
model = SVHN_Model1()
criterion = nn.CrossEntropyLoss(size_average=False)
optimizer = torch.optim.Adam(model.parameters(), 0.001)
best_loss = 1000.0
losses = []
for epoch in range(20):
print('Epoch: ', epoch)
train(train_loader, model, criterion, optimizer)
val_loss = validate(val_loader, model, criterion)
# 记录下验证集精度
if val_loss < best_loss:
best_loss = val_loss
torch.save(model.state_dict(), './model.pt')
# TODO: this script has a couple of bugs to fix (see val_label / val_dataset construction).