from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.utils as vutils
import time
import numpy as np
from numpy import *
from data_loader.dataset import train_dataset
from data_loader.dataset import val_dataset
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torchvision.models.segmentation as models
import cv2
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
from PIL import Image
from os import path
from models.Eca_ASP_v4_2 import eca_ASP_v4_2
# ---- Command-line configuration --------------------------------------------
parser = argparse.ArgumentParser(description='Training a Eca_ASP_v4 _u_pretrain model')
parser.add_argument('--batch_size', type=int, default=2, help='equivalent to instance normalization with batch_size=1')
parser.add_argument('--niter', type=int, default=200, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.5')
parser.add_argument('--manual_seed', type=int, help='manual seed')
parser.add_argument('--num_workers', type=int, default=0, help='how many threads of cpu to use while loading data')
parser.add_argument('--flip', type=int, default=1, help='1 for flipping image randomly, 0 for not')
parser.add_argument('--data_path', default='./data/train_384', help='path to training images')
parser.add_argument('--outf', default='./checkpoint/Eca_ASP_v4_2', help='folder to output images and model checkpoints')
# FIX: the next four options previously had no type=, so supplying them on the
# command line produced str values that crash later integer arithmetic
# (e.g. `epoch % opt.save_epoch`, `opt.snapshot + 1`). Defaults are unchanged.
parser.add_argument('--save_epoch', type=int, default=1, help='save_epoch')
parser.add_argument('--snapshot', type=int, default=100, help='snapshot_save_epoch')
parser.add_argument('--test_step', type=int, default=20, help='path to val images')
parser.add_argument('--log_step', type=int, default=1, help='path to val images')
parser.add_argument('--size_w', type=int, default=256, help='scale image to this size')
parser.add_argument('--size_h', type=int, default=256, help='scale image to this size')
opt = parser.parse_args()
# ---- Run setup: TensorBoard writer, output dirs, reproducibility -----------
writer = SummaryWriter()

# FIX: replaced `try: os.makedirs(...) except OSError: pass` with
# exist_ok=True — the broad except also silenced genuine failures such as
# permission errors, which then surfaced later as confusing open() crashes.
os.makedirs(opt.outf, exist_ok=True)
os.makedirs(opt.outf + '/model/', exist_ok=True)
os.makedirs(opt.outf + '/outpic&label/', exist_ok=True)

# Seed every RNG in use so a run can be reproduced; when no seed is given,
# draw one at random and print it so the run is still replayable.
if opt.manual_seed is None:
    opt.manual_seed = random.randint(1, 10000)
random.seed(opt.manual_seed)
torch.manual_seed(opt.manual_seed)
torch.cuda.manual_seed(opt.manual_seed)
# Let cuDNN autotune conv algorithms (safe here: input sizes are fixed).
cudnn.benchmark = True
print(opt)
print("Random Seed: ", opt.manual_seed)
# ---- Dataset, model, loss and optimiser ------------------------------------
# NOTE(review): `train_datatset_` is a typo for "dataset" but is referenced
# later in the training loop, so the name is kept.
train_datatset_ = train_dataset(opt.data_path, opt.size_w, opt.size_h, opt.flip)
train_loader = torch.utils.data.DataLoader(
    dataset=train_datatset_,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.num_workers,
)

net = eca_ASP_v4_2(layers=50, classes=1, pretrained=True, use_aux=True)
net.cuda()

########### LOSS & OPTIMIZER ##########
#criterion = nn.BCEWithLogitsLoss()
# BCELoss: assumes the network output is already sigmoid-activated — TODO confirm.
criterion = nn.BCELoss()
# SGD, used together with snapshot restarts + CosineAnnealingLR (set up below).
optimizer = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.9, weight_decay=0.0005)
#optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
#scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20], gamma=0.1)

# Snapshot learning-rate schedule ********************************************
min_lr = 0.0001                                 # floor (eta_min) of the cosine decay
scheduler_step = opt.snapshot + 1               # restart period, in epochs
iteration = len(train_loader) * opt.snapshot    # T_max, in per-batch optimiser steps
# T_max corresponds to half a cosine period; eta_min is the minimum LR (default 0).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iteration, eta_min=min_lr)
# ****************************************************************************
########### ---------------- ###########

# Seed the TensorBoard validation curves with a zero point at step 0.
writer.add_scalar('val_overall_iou', 0, 0)
writer.add_scalar('val_overall_acc', 0, 0)
if __name__ == '__main__':
    # Training log (loss samples) plus a second file originally meant for
    # validation output. NOTE(review): log1 is opened but never written to.
    log = open('%s/train_Unet_log.txt' % (opt.outf), 'w')
    log.write('"Random Seed:%d "' % (opt.manual_seed) + '\n')
    log1 = open('%s/val_Unet_log.txt' % (opt.outf), 'w')
    start = time.time()
    net.train()
    count = 0        # global-step offset for the TensorBoard train-loss curve
    countval = 0     # (unused) validation step counter, kept for compatibility
    best_iou = 0     # best overall IoU seen so far
    best_acc = 0     # best overall pixel accuracy seen so far

    # Validation-time preprocessing; hoisted out of the epoch loop since it is
    # constant. Mean/std look like dataset statistics — TODO confirm they
    # match the training-side normalisation.
    transform1 = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.31701732, 0.32337377, 0.28925751],
                             std=[0.17323045, 0.16700189, 0.16922423])
    ])

    for epoch in range(1, opt.niter + 1):
        loader = iter(train_loader)
        # Snapshot restart: every `scheduler_step` epochs rebuild optimiser
        # and cosine schedule so the learning rate jumps back up to opt.lr.
        if (epoch) % scheduler_step == 0:
            optimizer = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.9, weight_decay=0.0005)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iteration, eta_min=min_lr)

        net.train()
        for i in range(0, train_datatset_.__len__(), opt.batch_size):
            net.train()
            # FIX: was `loader.next()` — the Python-2 iterator protocol, which
            # does not exist on Python-3 DataLoader iterators. Use next().
            initial_image_, semantic_image_, name = next(loader)
            initial_image = initial_image_.cuda()
            semantic_image = semantic_image_.cuda()

            # Forward pass: main prediction plus auxiliary head (deep supervision).
            semantic_image_pred, aux = net(initial_image)
            main_loss = criterion(semantic_image_pred.view(-1), semantic_image.view(-1))
            aux_loss = criterion(aux.view(-1), semantic_image.view(-1))
            loss = main_loss + 0.4 * aux_loss  # 0.4 = PSPNet-style aux weight

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Per-iteration cosine step; placed after optimizer.step() per the
            # PyTorch >= 1.1 scheduler ordering.
            scheduler.step()

            ########### Logging ##########
            if i % (opt.batch_size * 20) == 0:
                writer.add_scalar('loss_step%d/train' % (opt.batch_size * 20), loss.item(), (count + i))
            if i % opt.log_step == 0:
                print('[%d/%d][%d/%d] Loss: %.4f LR: %.6f' % (epoch, opt.niter, i, len(train_loader) * opt.batch_size, loss.item(), optimizer.state_dict()['param_groups'][0]['lr']))
            if i % (opt.batch_size * 100) == 0:
                log.write('[%d/%d][%d/%d] Loss: %.4f' % (epoch, opt.niter, i, len(train_loader) * opt.batch_size, loss.item()) + '\n')
            count = len(train_loader) * opt.batch_size * epoch

        if epoch % opt.save_epoch == 0:
            # ---- Validation: binary IoU / pixel accuracy over the val set --
            net.eval()
            sumval = 0  # (unused) running validation loss, kept for compatibility
            with torch.no_grad():
                ref_folder = './data/train/vallabel/'
                file_names = os.listdir(ref_folder)
                pred_folder = './data/train/val/'
                num_val = len(file_names)
                inters_acum = 0
                union_acum = 0
                correct_acum = 0
                total_acum = 0
                result = "pic\t\t\tIoU %\tacc %\n"
                for idx in range(num_val):
                    # Ground truth: 0/255 tif mask -> {0,1} uint8.
                    ref = (np.array(Image.open(ref_folder + str(idx) + '.tif')) / 255.).astype(np.uint8)
                    pred = Image.open(pred_folder + str(idx) + '.tif').convert('RGB')
                    pred = transform1(pred)
                    pred = pred.unsqueeze(0)
                    pred = pred.cuda()
                    pred, aux = net(pred)
                    pred = pred.squeeze(0)
                    # Binarise at 0.5 (output assumed to lie in [0,1]).
                    pred[pred >= 0.5] = 1
                    pred[pred < 0.5] = 0
                    if idx % opt.test_step == 0:
                        # Periodically dump the predicted mask; on the first
                        # saved epoch also dump the matching label for
                        # side-by-side inspection.
                        a = pred.cpu()
                        a = transforms.ToPILImage()(a)
                        a.save(opt.outf + '/outpic&label/epoch_%d_%d.tif' % (epoch, idx))
                        if epoch == opt.save_epoch:
                            b = Image.open(ref_folder + str(idx) + '.tif')
                            b.save(opt.outf + '/outpic&label/label_%d.tif' % (idx))
                    pred = pred.cpu()
                    pred = (np.array(pred).astype(np.uint8))
                    inters = ref & pred
                    union = ref | pred
                    correct = ref == pred
                    inters_count = np.count_nonzero(inters)
                    union_count = np.count_nonzero(union)
                    correct_count = np.count_nonzero(correct)
                    total_count = ref.size
                    inters_acum += inters_count
                    union_acum += union_count
                    correct_acum += correct_count
                    total_acum += total_count
                    # Guard against 0/0 when both label and prediction are empty.
                    if float(union_count) == 0:
                        iou = 0
                    else:
                        iou = inters_count / float(union_count)
                    acc = correct_count / float(total_count)
                    result += "{0}{1}\t\t{2}%\t{3}%\n".format(idx, '.tif', round(iou * 100, 2), round(acc * 100, 2))
                overall_iou = inters_acum / float(union_acum)
                overall_acc = correct_acum / float(total_acum)
                result += "{0}\t{1}%\t{2}%\n".format("Overall", round(overall_iou * 100, 2),
                                                     round(overall_acc * 100, 2))
                print("#####################\n" + result + "#####################\n")
                final = ("\n#####################\n" + 'epoch:%d\n' % epoch)
                final += (result + "#####################\n")
                # FIX: was `opt.outf + "./eval.txt"`, which yields a path like
                # ".../Eca_ASP_v4_2./eval.txt" — a nonexistent directory on
                # POSIX filesystems (it only worked on Windows, which strips
                # trailing dots). Join the components properly instead.
                with open(os.path.join(opt.outf, "eval.txt"), "a") as evalfile:
                    evalfile.write(final)
                writer.add_scalar('val_overall_iou', round(overall_iou * 100, 2), epoch)
                writer.add_scalar('val_overall_acc', round(overall_acc * 100, 2), epoch)
                # Keep the best checkpoint by IoU and by accuracy separately.
                if overall_iou >= best_iou:
                    torch.save(net.state_dict(), '%s/model/netG_best_iou.pth' % (opt.outf))
                    best_iou = overall_iou
                if overall_acc >= best_acc:
                    torch.save(net.state_dict(), '%s/model/netG_best_acc.pth' % (opt.outf))
                    best_acc = overall_acc

    end = time.time()
    torch.save(net.state_dict(), '%s/model/netG_final.pth' % opt.outf)
    print('Program processed ', end - start, 's, ', (end - start) / 60, 'min, ', (end - start) / 3600, 'h')
    log.close()
    log1.close()  # FIX: was never closed (file-handle leak)
# (removed: non-code web-page residue accidentally pasted at end of file —
# it was not valid Python and would have caused a SyntaxError)