以基于 TuSimple 数据集的模型训练为例,首先是数据集的预处理。
每次读取一张图片时,会先裁掉原图和标签顶部的 cut_height 行,只保留下面的大半部分,然后进行数据增强(包括随机旋转、随机水平翻转、缩放到指定尺寸和归一化)。
def __getitem__(self, idx):
    """Load, crop and transform the sample at position ``idx``.

    The image (and, in training mode, its label map) is read with OpenCV,
    the top ``cfg.cut_height`` rows are dropped, and ``self.transform`` is
    applied when one is configured.

    Args:
        idx: index into ``full_img_path_list`` / ``label_list`` /
            ``exist_list`` / ``img_name_list``.

    Returns:
        dict with keys ``'img'`` (float tensor, axes permuted with
        ``(2, 0, 1)``) and ``'meta'`` (source path and image name).  In
        training mode it also carries ``'label'`` (long tensor) and
        ``'exist'`` (the raw ``exist_list`` entry).

    Raises:
        FileNotFoundError: if OpenCV cannot read the image or label file.
    """
    img_path = self.full_img_path_list[idx]
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError('Cannot read image: {}'.format(img_path))
    img = img.astype(np.float32)
    # Keep only the rows below cut_height as network input.
    img = img[self.cfg.cut_height:, :, :]

    if self.is_training:
        label_path = self.label_list[idx]
        label = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)
        if label is None:
            raise FileNotFoundError(
                'Cannot read label: {}'.format(label_path))
        if len(label.shape) > 2:
            # Multi-channel label image: only the first channel is used.
            label = label[:, :, 0]
        label = label.squeeze()
        label = label[self.cfg.cut_height:, :]
        exist = self.exist_list[idx]
        if self.transform:
            img, label = self.transform((img, label))
        # Labels are converted to long (integer class ids).
        label = torch.from_numpy(label).contiguous().long()
    elif self.transform:
        # Same guard as the training branch, so a dataset without a
        # transform does not fail with "NoneType is not callable".
        img, = self.transform((img,))

    # Image tensor is float, with the last axis moved to the front.
    img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
    meta = {'full_img_path': img_path,
            'img_name': self.img_name_list[idx]}
    data = {'img': img, 'meta': meta}
    if self.is_training:
        data.update({'label': label, 'exist': exist})
    return data
def transform_train(self):
    """Build the augmentation pipeline used for training samples.

    Returns:
        A ``torchvision.transforms.Compose`` that applies, in order: random
        rotation, random horizontal flip, resize to
        ``(cfg.img_width, cfg.img_height)`` and normalisation with
        ``cfg.img_norm``.
    """
    return torchvision.transforms.Compose([
        tf.GroupRandomRotation(),        # random rotation
        tf.GroupRandomHorizontalFlip(),  # random horizontal flip
        # Resize to the configured network input size.
        tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
        # Normalise; the second entries (0,) / (1,) presumably leave the
        # label map untouched -- TODO confirm against tf.GroupNormalize.
        tf.GroupNormalize(
            mean=(self.cfg.img_norm['mean'], (0, )),
            std=(self.cfg.img_norm['std'], (1, ))),
    ])
网络模型的代码
import torch.nn as nn
import torch
import torch.nn.functional as F
from models.registry import NET
from .resnet import ResNetWrapper
from .decoder import BUSD, PlainDecoder
class RESA(nn.Module):
    """Iterative shift-and-aggregate module applied to a feature map.

    For every iteration ``i`` in ``range(cfg.resa.iter)`` the module owns
    four bias-free convolutions (``conv_d{i}``, ``conv_u{i}``, ``conv_r{i}``,
    ``conv_l{i}``) and four cyclic index tensors (``idx_d{i}`` ...) whose
    shift is ``extent // 2 ** (iter - i)``.  ``d``/``u`` index the
    second-to-last axis of the input, ``r``/``l`` the last axis.
    """

    def __init__(self, cfg):
        super().__init__()
        self.iter = cfg.resa.iter
        chan = cfg.resa.input_channel
        fea_stride = cfg.backbone.fea_stride
        self.height = cfg.img_height // fea_stride
        self.width = cfg.img_width // fea_stride
        self.alpha = cfg.resa.alpha
        conv_stride = cfg.resa.conv_stride
        half = conv_stride // 2

        # (suffix, kernel size, padding) -- listed in creation order so that
        # parameter registration and RNG consumption stay d, u, r, l.
        conv_specs = (
            ('d', (1, conv_stride), (0, half)),
            ('u', (1, conv_stride), (0, half)),
            ('r', (conv_stride, 1), (half, 0)),
            ('l', (conv_stride, 1), (half, 0)),
        )
        # (suffix, axis length, shift sign)
        shift_specs = (
            ('d', self.height, 1),
            ('u', self.height, -1),
            ('r', self.width, 1),
            ('l', self.width, -1),
        )

        for step in range(self.iter):
            for suffix, kernel, pad in conv_specs:
                conv = nn.Conv2d(chan, chan, kernel,
                                 padding=pad, groups=1, bias=False)
                setattr(self, 'conv_' + suffix + str(step), conv)

            divisor = 2 ** (self.iter - step)
            for suffix, extent, sign in shift_specs:
                offset = sign * (extent // divisor)
                # Cyclic permutation of the axis positions.
                shifted = (torch.arange(extent) + offset) % extent
                setattr(self, 'idx_' + suffix + str(step), shifted)

    def forward(self, x):
        """Return a copy of ``x`` with all directional passes accumulated.

        The input tensor itself is not modified; the in-place additions are
        done on a clone.
        """
        x = x.clone()
        for direction in ('d', 'u', 'r', 'l'):
            along_rows = direction in ('d', 'u')
            for step in range(self.iter):
                conv = getattr(self, 'conv_' + direction + str(step))
                idx = getattr(self, 'idx_' + direction + str(step))
                gathered = x[..., idx, :] if along_rows else x[..., idx]
                x.add_(self.alpha * F.relu(conv(gathered)))
        return x
class ExistHead(nn.Module):
    """Head that maps a 128-channel feature map to per-lane existence scores.

    The output has ``cfg.num_classes - 1`` values per sample and has already
    been passed through a sigmoid.
    """

    def __init__(self, cfg=None):
        super().__init__()
        self.cfg = cfg
        self.dropout = nn.Dropout2d(0.1)
        self.conv8 = nn.Conv2d(128, cfg.num_classes, 1)
        # forward() halves the spatial size once more with a 2x2 average
        # pool, hence twice the backbone stride.
        stride = cfg.backbone.fea_stride * 2
        flat_features = int(
            cfg.num_classes * cfg.img_width / stride * cfg.img_height / stride)
        self.fc9 = nn.Linear(flat_features, 128)
        self.fc10 = nn.Linear(128, cfg.num_classes - 1)

    def forward(self, x):
        """Return sigmoid-activated existence scores for the features ``x``."""
        class_scores = self.conv8(self.dropout(x))
        class_probs = F.softmax(class_scores, dim=1)
        pooled = F.avg_pool2d(class_probs, 2, stride=2, padding=0)
        # Flatten everything except the batch axis.
        per_sample = pooled.numel() // pooled.shape[0]
        flat = pooled.view(-1, per_sample)
        hidden = F.relu(self.fc9(flat))
        return torch.sigmoid(self.fc10(hidden))
@NET.register_module
class RESANet(nn.Module):
    """Full network: backbone -> RESA -> (segmentation decoder, exist head)."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.backbone = ResNetWrapper(cfg)
        self.resa = RESA(cfg)
        # NOTE(review): eval() turns the config string (e.g. the name of one
        # of the decoder classes imported by this module) into a class.  It
        # executes arbitrary expressions, so only load trusted config files.
        decoder_cls = eval(cfg.decoder)
        self.decoder = decoder_cls(cfg)
        self.heads = ExistHead(cfg)

    def forward(self, batch):
        """Return ``{'seg': ..., 'exist': ...}`` for the input ``batch``."""
        features = self.resa(self.backbone(batch))
        return {
            'seg': self.decoder(features),
            'exist': self.heads(features),
        }
损失函数
可用交叉熵损失函数或者dice_loss损失函数
import torch.nn as nn
import torch
import torch.nn.functional as F
from runner.registry import TRAINER
def dice_loss(input, target):
    """Return the mean soft Dice loss between ``input`` and ``target``.

    Both tensors are flattened to ``(batch, -1)``; ``target`` is cast to
    float.  Per sample the loss is ``1 - 2*sum(x*y) / (sum(x*x) + sum(y*y))``
    where each sum in the denominator is offset by 0.001 so that it never
    becomes zero when both inputs are empty.
    """
    pred = input.reshape(input.size(0), -1)
    truth = target.reshape(target.size(0), -1).float()

    overlap = (pred * truth).sum(dim=1)
    pred_sq = (pred * pred).sum(dim=1) + 0.001
    truth_sq = (truth * truth).sum(dim=1) + 0.001

    dice = (2 * overlap) / (pred_sq + truth_sq)
    return (1 - dice).mean()
@TRAINER.register_module
class RESA(nn.Module):
    """Training loss wrapper: segmentation loss plus lane-existence loss.

    ``cfg.loss_type`` selects the segmentation loss: ``'cross_entropy'``
    (class-weighted NLL on log-softmax, background weight
    ``cfg.bg_weight``) or ``'dice_loss'``.

    Raises:
        ValueError: if ``cfg.loss_type`` is neither of the two supported
            values.  Previously such a config only failed later, inside
            ``forward``, with an AttributeError on ``self.criterion``.
    """

    SUPPORTED_LOSS_TYPES = ('cross_entropy', 'dice_loss')

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.loss_type = cfg.loss_type
        if self.loss_type not in self.SUPPORTED_LOSS_TYPES:
            raise ValueError(
                'Unsupported loss_type {!r}; expected one of {}'.format(
                    self.loss_type, self.SUPPORTED_LOSS_TYPES))
        if self.loss_type == 'cross_entropy':
            # Every class has weight 1 except class 0, which is re-weighted.
            weights = torch.ones(cfg.num_classes)
            weights[0] = cfg.bg_weight
            weights = weights.cuda()
            self.criterion = torch.nn.NLLLoss(
                ignore_index=self.cfg.ignore_label, weight=weights).cuda()
        # NOTE(review): BCEWithLogitsLoss applies a sigmoid internally.  If
        # the network's 'exist' output is already a probability this is a
        # second sigmoid -- confirm against the model head.
        self.criterion_exist = torch.nn.BCEWithLogitsLoss().cuda()

    def forward(self, net, batch):
        """Run ``net`` on ``batch['img']`` and compute the training loss.

        Args:
            net: callable returning a dict with ``'seg'`` and optionally
                ``'exist'``.
            batch: dict with ``'img'``, ``'label'`` and (when the net emits
                ``'exist'``) ``'exist'``.

        Returns:
            dict with ``'loss'`` (weighted sum) and ``'loss_stats'`` (the
            individual terms).
        """
        output = net(batch['img'])
        loss_stats = {}
        loss = 0.

        if self.loss_type == 'dice_loss':
            # One-hot encode to (N, C, H, W) and skip channel 0 on both
            # sides so only the non-background classes are scored.
            target = F.one_hot(
                batch['label'],
                num_classes=self.cfg.num_classes).permute(0, 3, 1, 2)
            seg_loss = dice_loss(
                F.softmax(output['seg'], dim=1)[:, 1:], target[:, 1:])
        else:  # 'cross_entropy' (validated in __init__)
            seg_loss = self.criterion(
                F.log_softmax(output['seg'], dim=1), batch['label'].long())

        loss += seg_loss * self.cfg.seg_loss_weight
        loss_stats.update({'seg_loss': seg_loss})

        if 'exist' in output:
            # The existence term is scaled by a fixed factor of 0.1.
            exist_loss = 0.1 * self.criterion_exist(
                output['exist'], batch['exist'].float())
            loss += exist_loss
            loss_stats.update({'exist_loss': exist_loss})

        return {'loss': loss, 'loss_stats': loss_stats}
开始训练,运行以下代码
python main.py configs/tusimple.py --gpus 0
用自己的视频来测试训练好的模型的性能
import os
import os.path as osp
import time
import shutil
import torch
import torchvision
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torch.optim
import cv2
import numpy as np
import models
import argparse
from utils.config import Config
from runner.runner import Runner
from datasets import build_dataloader
# Per-lane drawing colours: view() passes color_list[i] to cv2.circle for the
# i-th lane, so at most len(color_list) lanes can be drawn with this palette.
color_list =[
(255, 0, 0),
(255, 225, 0),
(255, 0, 255),
(125, 125, 125),
(255, 125, 125),
(0, 125, 0)
]
def main():
    """Run a trained model on a video file and display the detected lanes.

    The model is built by ``Runner`` from the config given on the command
    line.  Each frame is resized to 1280x720, cropped below
    ``cfg.cut_height``, resized to the network input size, normalised and
    fed to the network; the lane points extracted from the segmentation
    probabilities are drawn on the 1280x720 frame and shown in an OpenCV
    window.  The loop stops when the video has no more frames; the running
    FPS estimate and the total elapsed time are printed.

    Raises:
        FileNotFoundError: if the (hard-coded) video file cannot be opened.
    """
    # Only needed by this demo, hence imported locally.
    import utils.transforms as tf

    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in args.gpus)
    cfg = Config.fromfile(args.config)
    cfg.gpus = len(args.gpus)
    cfg.load_from = args.load_from
    cfg.finetune_from = args.finetune_from
    cfg.view = args.view
    cfg.work_dirs = args.work_dirs + '/' + cfg.dataset.train.type
    cudnn.benchmark = True
    cudnn.fastest = True
    runner = Runner(cfg)
    runner.net.eval()
    # NOTE(review): this loader is never consumed below.  The call is kept
    # only because its side effects are unknown from here -- remove it if
    # build_dataloader has none.
    val_loader = build_dataloader(cfg.dataset.val, cfg, is_train=False)

    def is_short(lane):
        """Return True when ``lane`` contains no positive x-coordinate."""
        return not any(x > 0 for x in lane)

    def fix_gap(coordinate):
        """Fill interior gaps (negative entries) of a lane by interpolation.

        ``coordinate`` is modified in place and also returned.  Entries
        before the first and after the last positive value are left alone.
        """
        if any(x > 0 for x in coordinate):
            valid = [i for i, x in enumerate(coordinate) if x > 0]
            start, end = valid[0], valid[-1]
            lane = coordinate[start:end + 1]
            if any(x < 0 for x in lane):
                gap_start = [i for i, x in enumerate(lane[:-1])
                             if x > 0 and lane[i + 1] < 0]
                gap_end = [i + 1 for i, x in enumerate(lane[:-1])
                           if x < 0 and lane[i + 1] > 0]
                gap_id = [i for i, x in enumerate(lane) if x < 0]
                if len(gap_start) == 0 or len(gap_end) == 0:
                    return coordinate
                for gid in gap_id:
                    for i in range(len(gap_start)):
                        if i >= len(gap_end):
                            return coordinate
                        if gap_start[i] < gid < gap_end[i]:
                            # Linear blend of the two valid neighbours.
                            gap_width = float(gap_end[i] - gap_start[i])
                            lane[gid] = int(
                                (gid - gap_start[i]) / gap_width * lane[gap_end[i]]
                                + (gap_end[i] - gid) / gap_width * lane[gap_start[i]])
                if not all(x > 0 for x in lane):
                    print("Gaps still exist!")
                coordinate[start:end + 1] = lane
        return coordinate

    def get_lane(prob_map, y_px_gap, pts, thresh, resize_shape=None):
        """Sample one lane's x-coordinates from its probability map.

        Arguments:
        ----------
        prob_map: prob map for a single lane, np array size (h, w)
        y_px_gap: vertical distance in pixels between sampled rows
        pts: number of rows to sample, bottom up
        thresh: minimum peak probability for a row to count as detected
        resize_shape: target size (H, W) of the full, uncropped image

        Return:
        ----------
        coords: x coords bottom up every y_px_gap px in the resized shape;
                -1 where nothing was detected, all zeros when fewer than two
                rows were detected
        """
        if resize_shape is None:
            resize_shape = prob_map.shape
        h, w = prob_map.shape
        H, W = resize_shape
        # The probability map only covers the rows below cut_height.
        H -= cfg.cut_height
        coords = np.full(pts, -1.0)
        for i in range(pts):
            y = int((H - 10 - i * y_px_gap) * h / H)
            if y < 0:
                break
            line = prob_map[y, :]
            peak = np.argmax(line)
            if line[peak] > thresh:
                coords[i] = int(peak / w * W)
        if (coords > 0).sum() < 2:
            coords = np.zeros(pts)
        fix_gap(coords)
        return coords

    def probmap2lane(seg_pred, exist, b, resize_shape=(720, 1280), smooth=True,
                     y_px_gap=10, pts=56, thresh=0.6):
        """Convert per-class probability maps into lists of lane points.

        Arguments:
        ----------
        seg_pred: np.array size (num_classes, h, w); channel 0 is background
        exist: list of existence flags, e.g. [0, 1, 1, 0] (currently unused)
        b: currently unused
        resize_shape: reshape size target, (H, W)
        smooth: whether to smooth the probability or not
        y_px_gap: y pixel gap for sampling
        pts: how many points for one lane
        thresh: probability threshold

        Return:
        ----------
        coordinates: [x, y] list of lanes, e.g.:
            [ [[9, 569], [50, 549]] ,[[630, 569], [647, 549]] ]
            x is -1 for rows without a detection.  When no lane is found a
            single all -1 lane is returned.
        """
        if resize_shape is None:
            resize_shape = seg_pred.shape[1:]
        H = resize_shape[0]

        def to_points(xs):
            return [[xs[j], H - 10 - j * y_px_gap] if xs[j] > 0
                    else [-1, H - 10 - j * y_px_gap]
                    for j in range(pts)]

        coordinates = []
        for i in range(cfg.num_classes - 1):
            prob_map = seg_pred[i + 1]  # seg_pred[0] is the background
            if smooth:
                prob_map = cv2.blur(
                    prob_map, (9, 9), borderType=cv2.BORDER_REPLICATE)
            coords = get_lane(prob_map, y_px_gap, pts, thresh, resize_shape)
            if is_short(coords):
                continue
            coordinates.append(to_points(coords))
        if not coordinates:
            coordinates.append(to_points(np.zeros(pts)))
        return coordinates

    def view(img, coords, file_path=None):
        """Draw every lane in ``coords`` onto ``img`` (in place).

        ``file_path`` is accepted for compatibility but not used.
        """
        for lane_idx, lane in enumerate(coords):
            # Wrap around so more lanes than colours cannot raise IndexError.
            color = color_list[lane_idx % len(color_list)]
            for x, y in lane:
                if x <= 0 or y <= 0:
                    continue
                cv2.circle(img, (int(x), int(y)), 4, color, 2)

    def transform_val():
        """Build the inference-time preprocessing (resize + normalise)."""
        return torchvision.transforms.Compose([
            # Network input size; presumably the same cfg the model was
            # built from (640x368 in the TuSimple config) -- TODO confirm.
            tf.SampleResize((cfg.img_width, cfg.img_height)),
            tf.GroupNormalize(mean=([103.939, 116.779, 123.68], (0, )), std=(
                [1., 1., 1.], (1, ))),
        ])

    # time.clock() was removed in Python 3.8; perf_counter() replaces it.
    time_start = time.perf_counter()
    fps = 0.0
    video_path = "/media/gooddz/新加卷/检测视频/极弯场景.mp4"
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        raise FileNotFoundError('Cannot open video: {}'.format(video_path))
    # The pipeline does not depend on the frame, so build it once.
    transform = transform_val()

    try:
        while True:
            t1 = time.time()
            ok, frame = capture.read()
            if not ok:
                # End of the video (or a read error): stop instead of
                # passing a None frame to cv2.resize.
                break
            frame = cv2.resize(frame, (1280, 720))
            frame_copy = frame.copy()
            # Same crop that get_lane() assumes when mapping rows back.
            frame = frame[cfg.cut_height:, :, :]
            frame = transform((frame,))
            frame = torch.from_numpy(frame[0]).permute(2, 0, 1).contiguous().float()
            frame = frame.unsqueeze(0).cuda()

            with torch.no_grad():
                output = runner.net(frame)
                seg_pred = F.softmax(output['seg'], dim=1)
            seg_pred = seg_pred.detach().cpu().numpy()
            exist_pred = output['exist'].detach().cpu().numpy()

            for b in range(len(seg_pred)):
                seg = seg_pred[b]
                exist_1 = [1 if exist_pred[b, i] > 0.5 else 0
                           for i in range(cfg.num_classes - 1)]
                lane_coords = probmap2lane(seg, exist_1, thresh=0.6, b=exist_1)
                for i in range(len(lane_coords)):
                    lane_coords[i] = sorted(
                        lane_coords[i], key=lambda pair: pair[1])
                view(frame_copy, lane_coords)

            # Running average of the instantaneous frame rate.
            fps = (fps + (1. / (time.time() - t1))) / 2
            cv2.imshow('imshow', frame_copy)
            cv2.waitKey(1)
            print("fps:", fps)
    finally:
        capture.release()
        cv2.destroyAllWindows()

    time_end = time.perf_counter()
    print(time_end - time_start)
def parse_args():
    """Parse the command-line arguments of the script.

    Returns:
        argparse.Namespace with ``config``, ``work_dirs``, ``load_from``,
        ``finetune_from``, ``validate``, ``view``, ``gpus`` (list of int)
        and ``seed``.
    """
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--work_dirs', type=str, default='work_dirs',
        help='work dirs')
    parser.add_argument(
        '--load_from', default='/home/llgj/桌面/ldz/resa-main_原/work_dirs/TuSimple/20220120_083126_lr_2e-02_b_4/ckpt/best.pth')
    parser.add_argument(
        '--finetune_from', default=None,
        help='whether to finetune from the checkpoint')
    parser.add_argument(
        '--validate',
        action='store_true',
        help='whether to evaluate the checkpoint during training')
    parser.add_argument(
        '--view',
        action='store_true',
        help='whether to show visualization result')
    # argparse does not apply ``type`` to a non-string default and does not
    # wrap it for nargs='+', so the default must already be a list of int
    # (it used to be the string '0').
    parser.add_argument('--gpus', nargs='+', type=int, default=[0])
    parser.add_argument('--seed', type=int,
                        default=None, help='random seed')
    return parser.parse_args()
# Script entry point: only run the demo when executed directly.
if __name__ == '__main__':
    main()
# Example command-line arguments:
#configs/tusimple.py --gpus 0
#configs/tusimple.py --validate --load_from /media/gooddz/学习/culane_resnet50.pth --gpus 0 --view
新建test.py,复制上面的代码,运行以下代码进行测试
python test.py configs/tusimple.py --validate --load_from /media/gooddz/学习/tusimple.pth --gpus 0 --view