# Continuing from the previous post's DataLoader, this file contains the training-time notes and code:
import sys
import torch
import argparse
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import shutil
import cv2
from torch.autograd import Variable
from torch.utils import data
import os
os.environ["CUDA_VISIBLE_DEVICES"] ="0"
from dataset import IC15Loader
from dataset import IC15TestLoader
from metrics import runningScore
import models
from tqdm import tqdm
from util import Logger, AverageMeter
import time
import util
import time
from pse import pse
from cal_recall import cal_recall_precison_f1
binary_th = 1
kernel_num = 7
scale = 1
long_size = 2240
min_kernel_area = 5.0
min_area = 800.0
min_score = 0.93
def extend_3c(img):
    """Replicate a single-channel (H, W) image into three identical channels (H, W, 3)."""
    single = img.reshape(img.shape[0], img.shape[1], 1)
    return np.concatenate((single, single, single), axis=2)
def debug(idx, img_paths, imgs, output_root):
    """Tile `imgs` (a list of rows, each a list of same-height images) into a
    single grid image and write it to output_root under the source file name.
    """
    if not os.path.exists(output_root):
        os.makedirs(output_root)
    # Concatenate each row horizontally, then stack the rows vertically.
    rows = [np.concatenate(row_imgs, axis=1) for row_imgs in imgs]
    grid = np.concatenate(rows, axis=0)
    img_name = img_paths[idx].split('/')[-1]
    cv2.imwrite(output_root + img_name, grid)
def write_result_as_txt(image_name, bboxes, path):
    """Write detected boxes to 'res_<image_name>.txt', one comma-separated
    line of 8 integer coordinates per box."""
    filename = util.io.join_path(path, 'res_%s.txt'%(image_name))
    lines = []
    for bbox in bboxes:
        coords = tuple(int(v) for v in bbox)
        lines.append("%d, %d, %d, %d, %d, %d, %d, %d\n" % coords)
    util.io.write_lines(filename, lines)
def polygon_from_points(points):
    """
    Returns a Polygon object to use with the Polygon2 class from a list of 8 points: x1,y1,x2,y2,x3,y3,x4,y4

    The flat coordinates are packed so that row 0 of the reshaped (2, 4)
    matrix holds the x values and row 1 the y values; the transpose yields a
    (4, 2) array of (x, y) vertices.

    NOTE(review): `plg` is never imported in this file, so calling this
    function raises NameError. It appears unused here — likely a leftover
    from the ICDAR evaluation script; `import Polygon as plg` would fix it.
    """
    resBoxes=np.empty([1, 8],dtype='int32')
    resBoxes[0, 0] = int(points[0])  # x1
    resBoxes[0, 4] = int(points[1])  # y1
    resBoxes[0, 1] = int(points[2])  # x2
    resBoxes[0, 5] = int(points[3])  # y2
    resBoxes[0, 2] = int(points[4])  # x3
    resBoxes[0, 6] = int(points[5])  # y3
    resBoxes[0, 3] = int(points[6])  # x4
    resBoxes[0, 7] = int(points[7])  # y4
    pointMat = resBoxes[0].reshape([2, 4]).T
    return plg.Polygon(pointMat)
def test(model,scale = 1):
    """Run inference on the IC15 test set, write per-image detection txt files
    and visualizations, then return (recall, precision, hmean) from the ICDAR
    evaluation script.

    NOTE(review): the `scale` parameter shadows the module-level `scale`
    constant (both default to 1).
    """
    data_loader = IC15TestLoader(long_size=long_size)
    test_loader = torch.utils.data.DataLoader(
        data_loader,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        drop_last=True)
    model = model.cuda()
    model.eval()
    total_frame = 0.0
    total_time = 0.0
    bar = tqdm(total= len(test_loader))
    for idx, (org_img, img) in enumerate(test_loader):
        sys.stdout.flush()
        bar.update(1)
        img = Variable(img.cuda())
        org_img = org_img.numpy().astype('uint8')[0]
        text_box = org_img.copy()
        torch.cuda.synchronize()
        start = time.time()
        with torch.no_grad():
            outputs = model(img)
        # Debug dumps of channel 0 (text map) and channel 6 (last kernel map).
        # NOTE(review): these three files are overwritten on every iteration.
        ind = 'cat_34'
        cv2.imwrite('text'+str(ind)+'.jpg',outputs[:, 0, :, :].data.cpu().numpy()[0].astype(np.uint8)*255)
        cv2.imwrite('kernel'+str(ind)+'.jpg',outputs[:, 6, :, :].data.cpu().numpy()[0].astype(np.uint8)*255)
        cv2.imwrite('ori'+str(ind)+'.jpg',org_img)
        score = torch.sigmoid(outputs[:, 0, :, :])  # text-region probability map
        # Binarize all maps at binary_th via sign(): values become 0 or 1.
        outputs = (torch.sign(outputs - binary_th) + 1) / 2
        text = outputs[:, 0, :, :]
        # Channels 0..kernel_num-1 (text map included as the largest kernel),
        # masked so kernels only survive inside the predicted text region.
        kernels = outputs[:, 0:kernel_num, :, :] * text
        score = score.data.cpu().numpy()[0].astype(np.float32)
        text = text.data.cpu().numpy()[0].astype(np.uint8)
        kernels = kernels.data.cpu().numpy()[0].astype(np.uint8)
        # c++ version pse: progressive scale expansion over the kernel maps
        pred = pse(kernels,min_kernel_area / (scale * scale))
        # python version pse
        # pred = pypse(kernels, args.min_kernel_area / (args.scale * args.scale))
        # (w, h) ratios mapping the prediction map back to the original image
        scale_im = (org_img.shape[1] * 1.0 / pred.shape[1], org_img.shape[0] * 1.0 / pred.shape[0])
        label = pred
        label_num = np.max(label) + 1
        bboxes = []
        for i in range(1, label_num):
            # (x, y) coordinates of every pixel in connected component i
            points = np.array(np.where(label == i)).transpose((1, 0))[:, ::-1]
            if points.shape[0] < min_area / (scale * scale):
                continue
            score_i = np.mean(score[label == i])
            if score_i < min_score:
                continue
            # Minimum-area rotated rectangle around the component, scaled back
            # to original-image coordinates.
            rect = cv2.minAreaRect(points)
            bbox = cv2.boxPoints(rect) * scale_im
            bbox = bbox.astype('int32')
            bboxes.append(bbox.reshape(-1))
        torch.cuda.synchronize()
        end = time.time()
        total_frame += 1
        total_time += (end - start)
        # print('fps: %.2f'%(total_frame / total_time))
        sys.stdout.flush()
        for bbox in bboxes:
            cv2.drawContours(text_box, [bbox.reshape(4, 2)], -1, (0, 255, 0), 2)
        image_name = data_loader.img_paths[idx].split('/')[-1].split('.')[0]
        tp = 34
        write_result_as_txt(image_name, bboxes, 'outputs/submit_ic15_cat_'+str(tp)+'/')
        text_box = cv2.resize(text_box, (text.shape[1], text.shape[0]))
        debug(idx, data_loader.img_paths, [[text_box]], 'outputs/vis_ic15_cat_'+str(tp)+'/')
    bar.close()
    sys.stdout.flush()
    # NOTE(review): `tp` is assigned inside the loop above; this line raises
    # NameError if the test loader yields no batches.
    result_dict = cal_recall_precison_f1('/src/notebooks/train_data/ch4_test_gts', 'outputs/submit_ic15_cat_'+str(tp)+'/')
    return result_dict['recall'], result_dict['precision'], result_dict['hmean']
def ohem_single(score, gt_text, training_mask):
    """Online hard example mining for a single image.

    Keeps every positive pixel plus the hardest (highest-scoring) negatives,
    capping negatives at 3x the positive count. Returns a float32 mask of
    shape (1, H, W).
    """
    def _as_batch(mask):
        # Reshape to (1, H, W) float32 so per-image masks can be concatenated.
        return mask.reshape(1, mask.shape[0], mask.shape[1]).astype('float32')

    positives = gt_text > 0.5
    ignored = positives & (training_mask <= 0.5)
    pos_num = int(np.sum(positives)) - int(np.sum(ignored))
    if pos_num == 0:
        # No usable positives: fall back to the raw training mask.
        # selected_mask = gt_text.copy() * 0 # may be not good
        return _as_batch(training_mask)

    negatives = gt_text <= 0.5
    neg_num = int(min(pos_num * 3, int(np.sum(negatives))))
    if neg_num == 0:
        return _as_batch(training_mask)

    # Threshold = score of the neg_num-th highest-scoring negative pixel.
    neg_scores_desc = np.sort(score[negatives])[::-1]
    threshold = neg_scores_desc[neg_num - 1]
    selected = ((score >= threshold) | positives) & (training_mask > 0.5)
    return _as_batch(selected)
def ohem_batch(scores, gt_texts, training_masks):
    """Apply OHEM per image across a batch.

    A text image usually has far more background pixels than text pixels;
    OHEM keeps only the hardest negatives so the effective positive:negative
    ratio is about 1:3. Returns a float CPU tensor of shape (B, H, W).
    """
    scores_np = scores.data.cpu().numpy()
    gts_np = gt_texts.data.cpu().numpy()
    masks_np = training_masks.data.cpu().numpy()
    per_image = [
        ohem_single(scores_np[i, :, :], gts_np[i, :, :], masks_np[i, :, :])
        for i in range(scores_np.shape[0])
    ]
    return torch.from_numpy(np.concatenate(per_image, 0)).float()
def dice_loss(input, target, mask):
    """Masked dice loss: 1 - mean over the batch of 2*|P.T| / (|P|^2 + |T|^2).

    `input` holds raw logits (sigmoid is applied here); `mask` zeroes out
    ignored pixels. The 0.001 smoothing terms guard against division by zero.
    """
    probs = torch.sigmoid(input).contiguous().view(input.size()[0], -1)
    flat_target = target.contiguous().view(target.size()[0], -1)
    flat_mask = mask.contiguous().view(mask.size()[0], -1)
    probs = probs * flat_mask
    flat_target = flat_target * flat_mask
    intersection = torch.sum(probs * flat_target, 1)
    pred_norm = torch.sum(probs * probs, 1) + 0.001
    target_norm = torch.sum(flat_target * flat_target, 1) + 0.001
    dice = torch.mean((2 * intersection) / (pred_norm + target_norm))
    return 1 - dice
def cal_text_score(texts, gt_texts, training_masks, running_metric_text):
    """Binarize the text predictions at 0.5, apply the training mask, and feed
    (gt, pred) into the running segmentation metric. Returns its score dict."""
    masks_np = training_masks.data.cpu().numpy()
    pred = torch.sigmoid(texts).data.cpu().numpy() * masks_np
    pred = (pred > 0.5).astype(np.int32)
    gt = (gt_texts.data.cpu().numpy() * masks_np).astype(np.int32)
    running_metric_text.update(gt, pred)
    score_text, _ = running_metric_text.get_scores()
    return score_text
def cal_kernel_score(kernels, gt_kernels, gt_texts, training_masks, running_metric_kernel):
    """Score the last (smallest) kernel map against its ground truth inside the
    (gt_text * training_mask) region, updating the running metric.
    Returns the metric's score dict."""
    valid = (gt_texts * training_masks).data.cpu().numpy()
    pred = torch.sigmoid(kernels[:, -1, :, :]).data.cpu().numpy()
    pred = ((pred > 0.5) * valid).astype(np.int32)
    gt = (gt_kernels[:, -1, :, :].data.cpu().numpy() * valid).astype(np.int32)
    running_metric_kernel.update(gt, pred)
    score_kernel, _ = running_metric_kernel.get_scores()
    return score_kernel
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch.

    Returns (avg loss, text mean acc, kernel mean acc, text mean IoU,
    kernel mean IoU).

    NOTE(review): `score_text`/`score_kernel` are referenced after the loop;
    this raises NameError if the loader yields no batches.
    """
    model.train()
    # AverageMeter computes and stores the average and current value.
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    # runningScore tracks overall/mean accuracy, mean IoU and fwavacc;
    # the argument 2 is the number of classes (text vs. background).
    running_metric_text = runningScore(2)
    running_metric_kernel = runningScore(2)
    end = time.time()
    for batch_idx, (imgs, gt_texts, gt_kernels, training_masks) in enumerate(train_loader):
        data_time.update(time.time() - end)
        imgs = Variable(imgs.cuda())
        gt_texts = Variable(gt_texts.cuda())
        gt_kernels = Variable(gt_kernels.cuda())
        training_masks = Variable(training_masks.cuda())
        # Model output is (B, 7, H, W), e.g. torch.Size([8, 7, 640, 640]):
        # channel 0 is the full text map, the other 6 are shrunk kernel maps.
        outputs = model(imgs)
        texts = outputs[:, 0, :, :]     # (B, H, W) text map
        kernels = outputs[:, 1:, :, :]  # (B, 6, H, W) multi-scale kernel maps per text instance
        # OHEM: keep every positive plus the hardest negatives (~1:3 ratio),
        # since background pixels vastly outnumber text pixels.
        selected_masks = ohem_batch(texts, gt_texts, training_masks)  # gt_texts / training_masks: (B, H, W)
        selected_masks = Variable(selected_masks.cuda())
        loss_text = criterion(texts, gt_texts, selected_masks)  # text segmentation loss
        loss_kernels = []
        # Kernel losses only count pixels where the text map is confidently
        # positive AND the pixel is not masked out.
        mask0 = torch.sigmoid(texts).data.cpu().numpy()
        mask1 = training_masks.data.cpu().numpy()
        selected_masks = ((mask0 > 0.5) & (mask1 > 0.5)).astype('float32')
        selected_masks = torch.from_numpy(selected_masks).float()
        selected_masks = Variable(selected_masks.cuda())
        for i in range(kernel_num-1):
            kernel_i = kernels[:, i, :, :]
            gt_kernel_i = gt_kernels[:, i, :, :]
            loss_kernel_i = criterion(kernel_i, gt_kernel_i, selected_masks)
            loss_kernels.append(loss_kernel_i)
        loss_kernel = sum(loss_kernels) / len(loss_kernels)  # mean kernel loss
        loss = 0.7 * loss_text + 0.3 * loss_kernel
        losses.update(loss.item(), imgs.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        score_text = cal_text_score(texts, gt_texts, training_masks, running_metric_text)
        score_kernel = cal_kernel_score(kernels, gt_kernels, gt_texts, training_masks, running_metric_kernel)
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 20 == 0:
            output_log = '({batch}/{size}) Batch: {bt:.3f}s | TOTAL: {total:.0f}min | ETA: {eta:.0f}min | Loss: {loss:.4f} | Acc_t: {acc: .4f} | IOU_t: {iou_t: .4f} | IOU_k: {iou_k: .4f}'.format(
                batch=batch_idx + 1,
                size=len(train_loader),
                bt=batch_time.avg,
                total=batch_time.avg * batch_idx / 60.0,
                eta=batch_time.avg * (len(train_loader) - batch_idx) / 60.0,  # estimated time remaining
                loss=losses.avg,  # running average of the combined text+kernel loss
                acc=score_text['Mean Acc'],
                iou_t=score_text['Mean IoU'],
                iou_k=score_kernel['Mean IoU'])
            print(output_log)
            sys.stdout.flush()
    # train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou
    return (losses.avg, score_text['Mean Acc'], score_kernel['Mean Acc'], score_text['Mean IoU'], score_kernel['Mean IoU'])
def adjust_learning_rate(args, optimizer, epoch):
    """Decay args.lr by 10x whenever `epoch` appears in args.schedule, and
    push the new rate into every optimizer parameter group.

    Mutates both `args.lr` and `optimizer` in place; returns None.
    (Removed a dead `global state` declaration: no module-level `state`
    exists and the name was never used in this function.)
    """
    if epoch in args.schedule:
        args.lr = args.lr * 0.1
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr
def save_checkpoint(state, checkpoint='checkpoint', filename='checkpoint1.pth.tar'):
    """Serialize `state` (a dict of training state) to checkpoint/filename via torch.save."""
    torch.save(state, os.path.join(checkpoint, filename))
def set_seed(seed):
    """Seed the python, numpy and torch (CPU and all CUDA devices) RNGs."""
    import random
    import numpy as np
    import torch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # CUDA seeding calls are deferred by torch and are safe without a GPU.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
# Per-worker RNG seeding state for DataLoader worker processes.
GLOBAL_WORKER_ID = None  # id of the current worker (set in worker_init_fn)
GLOBAL_SEED = 1000       # base seed; each worker offsets it by its worker_id
def worker_init_fn(worker_id):
    """DataLoader worker_init_fn: give each worker a distinct, reproducible seed."""
    global GLOBAL_WORKER_ID
    GLOBAL_WORKER_ID = worker_id
    set_seed(GLOBAL_SEED + worker_id)
def main(args):
    """Build the dataset, model and optimizer from CLI args, then run the
    training loop, checkpointing and logging after every epoch."""
    if args.checkpoint == '':
        args.checkpoint = "checkpoints/ic15_%s_bs_%d_ep_%d"%(args.arch, args.batch_size, args.n_epoch)
        if args.pretrain:
            if 'synth' in args.pretrain:
                args.checkpoint += "_pretrain_synth"
            else:
                args.checkpoint += "_pretrain_ic17"
    print ('checkpoint path: %s'%args.checkpoint)
    print ('init lr: %.8f'%args.lr)
    print ('schedule: ', args.schedule)
    sys.stdout.flush()  # flush buffered output to the console immediately
    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)
    # NOTE(review): this local kernel_num shadows the module-level constant
    # (both are 7, so behavior is unchanged).
    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0
    data_loader = IC15Loader(is_transform=True, img_size=args.img_size, kernel_num=kernel_num, min_scale=min_scale)
    # DataLoader combines a dataset and a sampler and loads batches using
    # multiple worker processes.
    train_loader = torch.utils.data.DataLoader(
        data_loader,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=7,
        worker_init_fn=worker_init_fn,
        drop_last=True,   # drop the final incomplete batch (e.g. with batch_size 64 and 100 samples, the last 36 are skipped)
        pin_memory=True)  # copy tensors into CUDA pinned memory before returning, speeding up host-to-GPU transfers
    # Backbone / architecture selection.
    if args.arch == "resnet18":
        model = models.resnet18_PAN(pretrained=True,add_ori=False)
    elif args.arch == "resnet34":
        model = models.resnet34_PAN(pretrained=True,add_ori=False)
    elif args.arch == "resnet50_PAN":
        model = models.resnet50_PAN(pretrained=True,add_ori=False,backbone='big')
    elif args.arch == "resnet50_common":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)
    elif args.arch == "sf_1":
        model = models.sf_1(pretrained=True)
    elif args.arch == "sf_2":
        model = models.sf_2(pretrained=True)
    elif args.arch == "bisenet_cat_18":
        model = models.BiSeNet(7, 'resnet18')
    elif args.arch == "bisenet_cat_34":
        model = models.BiSeNet(7, 'resnet34')
    elif args.arch == "bisenet_cat_50":
        model = models.BiSeNet(7, 'resnet50')
    elif args.arch == "bisenet_cat_101":
        model = models.BiSeNet(7, 'resnet101')
    elif args.arch == "bisenet_cat_152":
        model = models.BiSeNet(7, 'resnet152')
    model = torch.nn.DataParallel(model).cuda()
    # Prefer an optimizer owned by the model if it defines one.
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.99, weight_decay=5e-4)
    # NOTE(review): title is hard-coded and does not follow args.arch.
    title = 'resnet50_common'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(args.pretrain), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log1.txt'), title=title)
        # NOTE(review): only 4 column names here, but logger.append below
        # passes 7 values -- confirm Logger tolerates the mismatch.
        logger.set_names(['Learning Rate', 'Train Loss','Train Acc.', 'Train IOU.'])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log1.txt'), title=title, resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log1.txt'), title=title)
        logger.set_names(['Learning Rate', 'Train Loss','Train Acc.', 'Train IOU.','recall','precision','f1'])
    # NOTE(review): f1_ori is never used; presumably intended for best-f1 tracking.
    f1_ori = 0
    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr']))
        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(train_loader, model, dice_loss, optimizer, epoch)
        # Per-epoch evaluation is disabled; metrics are logged as zeros.
        recall, precision, f1 = 0,0,0
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'lr': args.lr,
            'optimizer': optimizer.state_dict(),
        }, checkpoint=args.checkpoint)
        logger.append([optimizer.param_groups[0]['lr'], train_loss, train_te_acc, train_te_iou,recall, precision, f1])
    logger.close()
if __name__ == '__main__':
    # Command-line hyper-parameters for training.
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument('--arch', nargs='?', type=str, default='resnet50_common')  # backbone key dispatched in main()
    parser.add_argument('--img_size', nargs='?', type=int, default=640,
                        help='Height of the input image')
    parser.add_argument('--n_epoch', nargs='?', type=int, default=600,
                        help='# of the epochs')
    parser.add_argument('--schedule', type=int, nargs='+', default=[200,400,550],
                        help='Decrease learning rate at these epochs.')
    parser.add_argument('--batch_size', nargs='?', type=int, default=8,
                        help='Batch Size')
    parser.add_argument('--lr', nargs='?', type=float, default=1e-3,
                        help='Learning Rate')
    parser.add_argument('--resume', nargs='?', type=str, default='',
                        help='Path to previous saved model to restart from')
    parser.add_argument('--pretrain', nargs='?', type=str, default=None,
                        help='Path to previous saved model to restart from')
    parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                        help='path to save checkpoint (default: checkpoint)')
    args = parser.parse_args()
    main(args)