# Train
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from datetime import datetime
from major_dataset import LoadDataset
from major_evalution import eval_semantic_segmentation
import major_config
# Module-level training setup: datasets, loaders, model, loss and optimizer,
# all configured from major_config.
Load_train = LoadDataset([major_config.train_image, major_config.train_label], major_config.crop_size)
Load_val = LoadDataset([major_config.val_image, major_config.val_label], major_config.crop_size)
train_data = DataLoader(Load_train, batch_size=major_config.batchsize, shuffle=True, num_workers=1)
# Shuffling validation data gives no benefit and makes the batch-averaged
# validation metrics vary from run to run, so iterate in a fixed order.
val_data = DataLoader(Load_val, batch_size=major_config.batchsize, shuffle=False, num_workers=1)
net = major_config.model.to(major_config.device)
# NLLLoss expects log-probabilities; the training/eval loops apply
# F.log_softmax to the raw network output before calling this criterion.
criterion = nn.NLLLoss().to(major_config.device)
optimizer = optim.Adam(net.parameters(), lr=1e-4)
def train(model):
    """Train ``model`` for ``major_config.num_epoch`` epochs on ``train_data``.

    Each epoch makes one pass over the module-level ``train_data`` loader and
    updates weights with the module-level ``criterion`` and ``optimizer``.
    It prints the per-batch loss, then the epoch-averaged PA, MPA, MIoU and
    FWIoU reported by ``eval_semantic_segmentation``.

    The learning rate of every optimizer param group is halved at the start
    of epochs 20, 40, 60, ... (0-based epoch index).

    Whenever the epoch-mean *training* MIoU matches or beats the best seen
    so far, the model's ``state_dict`` is saved to
    ``major_config.path_saved_model``.

    Args:
        model: the network to train; switched to training mode via
            ``model.train()``.
    """
    best_miou = 0.0
    net = model.train()
    # Guard the epoch averages against an empty loader (ZeroDivisionError).
    num_batches = max(len(train_data), 1)
    for epoch in range(major_config.num_epoch):
        print('Epoch is [{}/{}]'.format(epoch + 1, major_config.num_epoch))
        # Step decay: halve the learning rate every 20 epochs.
        if epoch % 20 == 0 and epoch != 0:
            for group in optimizer.param_groups:
                group['lr'] *= 0.5
        train_loss = 0
        train_pa = 0
        train_mpa = 0
        train_miou = 0
        train_fwiou = 0
        for batch_idx, sample in enumerate(train_data):
            img_data = sample['img'].to(major_config.device)
            # NLLLoss requires integer (long) class-index targets; evaluate()
            # casts its labels the same way.
            img_label = sample['label'].long().to(major_config.device)
            out = F.log_softmax(net(img_data), dim=1)
            loss = criterion(out, img_label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_loss = loss.item()
            train_loss += batch_loss
            # Per-pixel predicted class = argmax over dim 1 of the output;
            # split into a per-sample list for the metric function.
            pre_label = list(out.argmax(dim=1).detach().cpu().numpy())
            true_label = list(img_label.detach().cpu().numpy())
            eval_metrix = eval_semantic_segmentation(pre_label, true_label)
            train_pa += eval_metrix['pa']
            train_mpa += eval_metrix['mpa']
            train_miou += eval_metrix['miou']
            train_fwiou += eval_metrix['fwiou']
            print('|batch[{}/{}]|batch_loss {: .8f}|'.format(batch_idx + 1, len(train_data), batch_loss))
        epoch_miou = train_miou / num_batches
        metric_description = '|Train PA|: {:.5f}|\n|Train MPA|: {:.5f}|\n|Train MIou|: {:.5f}|\n|Train FWIou|: {:.5f}|'.format(
            train_pa / num_batches,
            train_mpa / num_batches,
            epoch_miou,
            train_fwiou / num_batches,
        )
        print(metric_description)
        # NOTE(review): checkpoint selection uses the training MIoU, not a
        # validation score; confirm that is intended.
        if best_miou <= epoch_miou:
            best_miou = epoch_miou
            torch.save(net.state_dict(), major_config.path_saved_model)
def evaluate(model):
    """Run ``model`` once over ``val_data`` and print validation metrics.

    The model is switched to evaluation mode and run under
    ``torch.no_grad()``. Batch metrics from ``eval_semantic_segmentation``
    are averaged over the number of validation batches. The elapsed
    wall-clock time is printed as ``h:m:s``.

    Args:
        model: the network to evaluate; switched to eval mode via
            ``model.eval()``.
    """
    net = model.eval()
    eval_loss = 0
    eval_acc = 0
    eval_miou = 0
    # NOTE(review): nothing below accumulates into eval_class_acc, so
    # "Valid Class Acc" always prints 0.0. Confirm which metric key was
    # meant to feed it.
    eval_class_acc = 0
    # Guard the averages against an empty loader (ZeroDivisionError).
    num_batches = max(len(val_data), 1)
    prec_time = datetime.now()
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        for sample in val_data:
            valImg = sample['img'].to(major_config.device)
            # NLLLoss requires integer (long) class-index targets.
            valLabel = sample['label'].long().to(major_config.device)
            out = F.log_softmax(net(valImg), dim=1)
            loss = criterion(out, valLabel)
            eval_loss += loss.item()
            pre_label = list(out.argmax(dim=1).cpu().numpy())
            true_label = list(valLabel.cpu().numpy())
            eval_metrics = eval_semantic_segmentation(pre_label, true_label)
            # NOTE(review): train() reads keys 'pa'/'mpa'/'miou'/'fwiou';
            # confirm eval_semantic_segmentation also returns
            # 'mean_class_accuracy'.
            eval_acc += eval_metrics['mean_class_accuracy']
            eval_miou += eval_metrics['miou']
    cur_time = datetime.now()
    h, remainder = divmod((cur_time - prec_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = 'Time: {:.0f}:{:.0f}:{:.0f}'.format(h, m, s)
    # Average the validation loss over *validation* batches (the original
    # divided by the training-loader length).
    val_str = ('|Valid Loss|: {:.5f} \n|Valid Acc|: {:.5f} \n|Valid Mean IU|: {:.5f} \n|Valid Class Acc|:{:}'.format(
        eval_loss / num_batches,
        eval_acc / num_batches,
        eval_miou / num_batches,
        eval_class_acc / num_batches))
    print(val_str)
    print(time_str)
if __name__ == "__main__":
    # Script entry point: train the module-level network built above.
    train(model=net)