# Training script for the Eca_ASP_v4_2 segmentation model.

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.utils as vutils
import time
import numpy as np
from numpy import *
from data_loader.dataset import train_dataset
from data_loader.dataset import val_dataset
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torchvision.models.segmentation as models
import cv2
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
from PIL import Image
from os import path

from models.Eca_ASP_v4_2 import eca_ASP_v4_2


###########   ARGUMENTS, OUTPUT DIRS, RNG SEEDING   ##########
parser = argparse.ArgumentParser(description='Training a Eca_ASP_v4 _u_pretrain model')
parser.add_argument('--batch_size', type=int, default=2, help='equivalent to instance normalization with batch_size=1')
parser.add_argument('--niter', type=int, default=200, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.5')
parser.add_argument('--manual_seed', type=int, help='manual seed')
parser.add_argument('--num_workers', type=int, default=0, help='how many threads of cpu to use while loading data')
parser.add_argument('--flip', type=int, default=1, help='1 for flipping image randomly, 0 for not')
parser.add_argument('--data_path', default='./data/train_384', help='path to training images')
parser.add_argument('--outf', default='./checkpoint/Eca_ASP_v4_2', help='folder to output images and model checkpoints')
# BUG FIX: the four integer options below were declared without type=int, so a
# value given on the command line arrived as a str and broke the arithmetic
# that uses them later (e.g. `opt.snapshot + 1`, `i % opt.log_step`).
# Their copy-pasted help strings ('path to val images') are corrected as well.
parser.add_argument('--save_epoch', type=int, default=1, help='run validation / checkpointing every this many epochs')
parser.add_argument('--snapshot', type=int, default=100, help='epochs per cosine-annealing snapshot cycle')
parser.add_argument('--test_step', type=int, default=20, help='save every Nth validation prediction as an image')
parser.add_argument('--log_step', type=int, default=1, help='print training status every this many iterations')

parser.add_argument('--size_w', type=int, default=256, help='scale image to this size')
parser.add_argument('--size_h', type=int, default=256, help='scale image to this size')

opt = parser.parse_args()
writer = SummaryWriter()
# Create the output directory tree; tolerate "already exists".
try:
    os.makedirs(opt.outf)
    os.makedirs(opt.outf + '/model/')
    os.makedirs(opt.outf + '/outpic&label/')
except OSError:
    pass
# Seed every RNG for reproducibility; draw a random seed if none was given.
if opt.manual_seed is None:
    opt.manual_seed = random.randint(1, 10000)
random.seed(opt.manual_seed)
torch.manual_seed(opt.manual_seed)
torch.cuda.manual_seed(opt.manual_seed)
cudnn.benchmark = True  # let cuDNN autotune kernels for the fixed input size
print(opt)
print("Random Seed: ", opt.manual_seed)

# Dataset, model, loss, optimizer and LR-scheduler construction.
# (Note: `train_datatset_` is a pre-existing misspelling kept because the
# training loop below references it by this exact name.)
train_datatset_ = train_dataset(opt.data_path, opt.size_w, opt.size_h, opt.flip)
train_loader = torch.utils.data.DataLoader(dataset=train_datatset_, batch_size=opt.batch_size, shuffle=True,num_workers=opt.num_workers)


# Single-class segmentation network with an auxiliary (deep-supervision) head.
net = eca_ASP_v4_2(layers=50,  classes=1, pretrained=True,use_aux=True)
net.cuda()
###########   LOSS & OPTIMIZER   ##########

#criterion = nn.BCEWithLogitsLoss()
# NOTE(review): BCELoss expects probabilities, i.e. the net must already apply
# a sigmoid to its outputs -- confirm against the model definition.
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=opt.lr,momentum=0.9,weight_decay=0.0005) # SGD combined with snapshot restarts + CosineAnnealingLR
#optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
#scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20], gamma=0.1)

# Snapshot learning-rate schedule ***************************************************************************

min_lr = 0.0001
scheduler_step= opt.snapshot + 1  # epochs between snapshot restarts of the optimizer/scheduler
iteration = len(train_loader) * opt.snapshot  # scheduler steps (batches) in one cosine cycle
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = iteration,eta_min=min_lr) # T_max = steps per half cosine period; eta_min = LR floor (default 0)

# Snapshot learning-rate schedule ***************************************************************************
###########   ----------------  ###########

# Seed the TensorBoard validation curves with a (step 0, value 0) point.
writer.add_scalar('val_overall_iou', 0, 0)
writer.add_scalar('val_overall_acc', 0, 0)
if __name__ == '__main__':

    # Plain-text training / validation logs inside the output folder.
    log = open('%s/train_Unet_log.txt' % (opt.outf), 'w')
    log.write('"Random Seed:%d "' % (opt.manual_seed) + '\n')
    log1 = open('%s/val_Unet_log.txt' % (opt.outf), 'w')
    start = time.time()
    net.train()
    count = 0      # global-step offset for the TensorBoard training curve
    countval = 0   # (reserved) validation step counter
    best_iou = 0   # best overall validation IoU seen so far
    best_acc = 0   # best overall validation accuracy seen so far

    # Normalisation applied to validation images (dataset mean/std).
    # Hoisted out of the epoch loop: it is invariant, so there is no reason
    # to rebuild it on every validation pass.
    transform1 = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.31701732, 0.32337377, 0.28925751],
                             std=[0.17323045, 0.16700189, 0.16922423])
    ])

    for epoch in range(1, opt.niter + 1):
        loader = iter(train_loader)

        # Snapshot restart: every `scheduler_step` epochs rebuild the optimizer
        # at the initial LR and begin a fresh cosine-annealing cycle.
        if (epoch) % scheduler_step == 0:
            optimizer = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.9, weight_decay=0.0005)
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iteration, eta_min=min_lr)

        net.train()

        for i in range(0, len(train_datatset_), opt.batch_size):
            # BUG FIX: was `loader.next()` -- the Python-2 iterator protocol,
            # which does not exist on Python 3 / current DataLoader iterators.
            initial_image_, semantic_image_, name = next(loader)

            initial_image = initial_image_.cuda()
            semantic_image = semantic_image_.cuda()

            # Forward pass: main prediction plus auxiliary (deep-supervision) output.
            semantic_image_pred, aux = net(initial_image)

            main_loss = criterion(semantic_image_pred.view(-1), semantic_image.view(-1))
            aux_loss = criterion(aux.view(-1), semantic_image.view(-1))
            loss = main_loss + 0.4 * aux_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Per-iteration LR step (PyTorch >= 1.1: call after optimizer.step()).
            scheduler.step()

            ########### Logging ##########
            if i % (opt.batch_size * 20) == 0:
                writer.add_scalar('loss_step%d/train' % (opt.batch_size * 20), loss.item(), (count + i))

            if i % opt.log_step == 0:
                print('[%d/%d][%d/%d] Loss: %.4f LR: %.6f' % (epoch, opt.niter, i, len(train_loader) * opt.batch_size, loss.item(), optimizer.state_dict()['param_groups'][0]['lr']))
            if i % (opt.batch_size * 100) == 0:
                log.write('[%d/%d][%d/%d] Loss: %.4f' % (epoch, opt.niter, i, len(train_loader) * opt.batch_size, loss.item()) + '\n')
        count = len(train_loader) * opt.batch_size * epoch

        # -------- validation + checkpointing --------
        if epoch % opt.save_epoch == 0:
            net.eval()
            with torch.no_grad():
                ref_folder = './data/train/vallabel/'
                file_names = os.listdir(ref_folder)
                pred_folder = './data/train/val/'
                num_val = len(file_names)

                # Accumulators for the dataset-level (overall) IoU / accuracy.
                inters_acum = 0
                union_acum = 0
                correct_acum = 0
                total_acum = 0

                result = "pic\t\t\tIoU %\tacc %\n"

                for i in range(num_val):
                    # Ground-truth mask, binarised to {0, 1}.
                    ref = (np.array(Image.open(ref_folder + str(i) + '.tif')) / 255.).astype(np.uint8)

                    pred = Image.open(pred_folder + str(i) + '.tif').convert('RGB')
                    pred = transform1(pred)
                    pred = pred.unsqueeze(0)
                    pred = pred.cuda()
                    pred, aux = net(pred)
                    pred = pred.squeeze(0)
                    # Threshold probabilities to a hard {0, 1} mask.
                    pred[pred >= 0.5] = 1
                    pred[pred < 0.5] = 0

                    # Periodically dump predictions (and, on the first validation
                    # pass only, the reference labels) as images.
                    if i % opt.test_step == 0:
                        a = pred.cpu()
                        a = transforms.ToPILImage()(a)
                        a.save(opt.outf + '/outpic&label/epoch_%d_%d.tif' % (epoch, i))
                        if epoch == opt.save_epoch:
                            b = Image.open(ref_folder + str(i) + '.tif')
                            b.save(opt.outf + '/outpic&label/label_%d.tif' % (i))

                    pred = pred.cpu()
                    pred = (np.array(pred).astype(np.uint8))
                    inters = ref & pred
                    union = ref | pred
                    correct = ref == pred

                    inters_count = np.count_nonzero(inters)
                    union_count = np.count_nonzero(union)
                    correct_count = np.count_nonzero(correct)
                    total_count = ref.size

                    inters_acum += inters_count
                    union_acum += union_count
                    correct_acum += correct_count
                    total_acum += total_count
                    # Empty union (both masks blank): define IoU as 0 to avoid /0.
                    if float(union_count) == 0:
                        iou = 0
                    else:
                        iou = inters_count / float(union_count)
                    acc = correct_count / float(total_count)

                    result += "{0}{1}\t\t{2}%\t{3}%\n".format(i, '.tif', round(iou * 100, 2), round(acc * 100, 2))

                overall_iou = inters_acum / float(union_acum)
                overall_acc = correct_acum / float(total_acum)

                result += "{0}\t{1}%\t{2}%\n".format("Overall", round(overall_iou * 100, 2),
                                                     round(overall_acc * 100, 2))

            print("#####################\n" + result + "#####################\n")
            final = ("\n#####################\n" + 'epoch:%d\n' % epoch)
            final += (result + "#####################\n")
            # BUG FIX: was `opt.outf + "./eval.txt"`, which yields a path like
            # ".../Eca_ASP_v4_2./eval.txt" -- a nonexistent "<outf>." directory
            # component, so the open() failed on POSIX filesystems.
            with open(os.path.join(opt.outf, "eval.txt"), "a") as evalfile:
                evalfile.write(final)
            writer.add_scalar('val_overall_iou', round(overall_iou * 100, 2), epoch)
            writer.add_scalar('val_overall_acc', round(overall_acc * 100, 2), epoch)
            # Checkpoint whenever a validation metric reaches a new best.
            if overall_iou >= best_iou:
                torch.save(net.state_dict(), '%s/model/netG_best_iou.pth' % (opt.outf))
                best_iou = overall_iou
            if overall_acc >= best_acc:
                torch.save(net.state_dict(), '%s/model/netG_best_acc.pth' % (opt.outf))
                best_acc = overall_acc

    end = time.time()
    torch.save(net.state_dict(), '%s/model/netG_final.pth' % opt.outf)
    print('Program processed ', end - start, 's, ', (end - start) / 60, 'min, ', (end - start) / 3600, 'h')
    log.close()
    # BUG FIX: close the validation log handle too (it was leaked).
    log1.close()

# (Removed: scraped CSDN blog-page footer text — not part of the source code.)