SCNN, Line-CNN, FastDraw
https://github.com/lucastabelini/PolyLaneNet
https://arxiv.org/pdf/2004.10924.pdf
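Model prediction outputs
The first four values, a, are the polynomial coefficients.
h is the vertical position of the horizon line, which helps to define the upper limit of the lane markings.
c is the lane confidence score; each lane has its own.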
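To make a, h and c concrete, here is a minimal sketch of turning one lane's outputs into points. It assumes normalized image coordinates (y grows downward), the coefficient order used in the loss code (highest degree first), and a hypothetical bottom end point `y_bottom` and cut-off `conf_threshold`. It is an illustration, not the repository's decoding routine.

```python
def decode_lane(coeffs, h, c, y_bottom=1.0, conf_threshold=0.5, num_points=5):
    """Return (x, y) points for one lane, or None if its score c is too low.

    coeffs: (a3, a2, a1, a0), highest degree first (assumed order)
    h: horizon line, used here as the top (upper limit) of the lane
    y_bottom, conf_threshold: hypothetical parameters for this sketch
    """
    if c < conf_threshold:
        return None  # drop low-confidence lanes
    if num_points < 2:
        raise ValueError("num_points must be at least 2")
    a3, a2, a1, a0 = coeffs
    step = (y_bottom - h) / (num_points - 1)
    points = []
    for i in range(num_points):
        y = h + i * step  # sample from the horizon down to y_bottom
        x = a3 * y**3 + a2 * y**2 + a1 * y + a0  # evaluate the cubic x = f(y)
        points.append((x, y))
    return points
```

For example, `decode_lane((0.0, 0.0, 1.0, 0.0), h=0.5, c=0.9, num_points=3)` describes the straight line x = y sampled at y = 0.5, 0.75 and 1.0.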
Key points of the paper
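Target label
Category, s (the vertical offset, i.e. the lower limit of the lane marking), h, and the coordinates of the points on the lane (xs, ys).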
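The loss code reads each lane's target as a flat row: category, s and h first, then a point block whose first half holds the xs and whose second half holds the ys, with negative xs treated as invalid. A minimal sketch of packing and unpacking such a row, assuming a hypothetical padding value of -1.0 (any negative value would be masked out the same way):

```python
def build_lane_target(category, s, h, points, max_points, pad=-1.0):
    """Pack one lane as [category, s, h, x_1..x_N, y_1..y_N]; pad is hypothetical."""
    if len(points) > max_points:
        raise ValueError("too many points for this lane")
    n_pad = max_points - len(points)
    xs = [x for x, _ in points] + [pad] * n_pad
    ys = [y for _, y in points] + [pad] * n_pad
    return [float(category), s, h] + xs + ys


def split_lane_target(row):
    """Unpack a row the way the loss slices it: first half xs, second half ys."""
    category, s, h = row[0], row[1], row[2]
    pts = row[3:]
    xs, ys = pts[: len(pts) // 2], pts[len(pts) // 2 :]
    valid = [x >= 0 for x in xs]  # negative x marks a padded point
    return category, s, h, xs, ys, valid
```

Fixed-length rows let every lane in a batch share one tensor shape, and the `x >= 0` mask later removes the padding from the loss.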
Loss code
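# Fragment of the loss function: extra_outputs, target, target_categories, target_points,
# pred_polys, valid_lanes_idx(_flat), mse and threshold are set up earlier (not shown).

# classification loss
if self.pred_category and self.extra_outputs > 0:
    ce = nn.CrossEntropyLoss()
    pred_categories = extra_outputs.reshape(target.shape[0] * target.shape[1], -1)
    target_categories = target_categories.reshape(pred_categories.shape[:-1]).long()
    # only lanes with target_categories > 0 (annotated lanes) are included
    pred_categories = pred_categories[target_categories > 0]
    target_categories = target_categories[target_categories > 0]
    # categories start at 1, so shift them to 0-based class indices
    cls_loss = ce(pred_categories, target_categories - 1)
else:
    cls_loss = 0

# poly loss calc
# first half of each point row holds the xs, second half the ys
target_xs = target_points[valid_lanes_idx_flat, :target_points.shape[1] // 2]
ys = target_points[valid_lanes_idx_flat, target_points.shape[1] // 2:].t()  # (num_points, num_lanes)
valid_xs = target_xs >= 0  # negative xs mark padded points
pred_polys = pred_polys[valid_lanes_idx_flat]
# evaluate the predicted cubic at the target ys to get the predicted xs
pred_xs = pred_polys[:, 0] * ys**3 + pred_polys[:, 1] * ys**2 + pred_polys[:, 2] * ys + pred_polys[:, 3]
pred_xs.t_()  # -> (num_lanes, num_points)
# per-lane weight: sqrt(total valid points / valid points in this lane)
weights = (torch.sum(valid_xs, dtype=torch.float32) / torch.sum(valid_xs, dim=1, dtype=torch.float32))**0.5
pred_xs = (pred_xs.t_() *
           weights).t()  # without this, lanes with more points would have more weight on the cost function
target_xs = (target_xs.t_() * weights).t()
# plain MSE; this value is overwritten by the thresholded version below
poly_loss = mse(pred_xs[valid_xs], target_xs[valid_xs]) / valid_lanes_idx.sum()
# `threshold` (an nn.Threshold module created earlier) zeroes small squared errors
poly_loss = threshold(
    (pred_xs[valid_xs] - target_xs[valid_xs])**2).sum() / (valid_lanes_idx.sum() * valid_xs.sum())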
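The classification branch above scores only lanes whose target category is positive and shifts categories down by one before the cross-entropy. A plain-Python sketch of that masking, with a hand-written cross-entropy (log-sum-exp minus the true-class logit) standing in for `nn.CrossEntropyLoss`; returning 0.0 when no lane qualifies is an arbitrary choice made for this sketch:

```python
import math


def cls_loss(pred_logits, target_categories):
    """Mean cross-entropy over lanes whose target category is > 0."""
    losses = []
    for logits, category in zip(pred_logits, target_categories):
        if category <= 0:
            continue  # category 0 means "no lane"; such slots are skipped
        label = category - 1  # categories start at 1, class indices at 0
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
        losses.append(log_sum_exp - logits[label])
    # arbitrary choice for this sketch: 0.0 when no lane has a positive category
    return sum(losses) / len(losses) if losses else 0.0
```

With logits `[[0.0, 0.0], [5.0, -5.0]]` and categories `[1, 0]`, only the first lane is scored, giving a loss of ln 2.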
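For the poly loss, the predicted polynomial is evaluated at the target y-coordinates ys to give the predicted x-coordinates pred_xs.
The squared error between pred_xs and target_xs is then computed over the valid points. The plain MSE line is overwritten by the thresholded version, which zeroes squared errors not above the threshold and divides the sum by (number of valid lanes × number of valid points).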
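A plain-Python sketch of the thresholded, per-lane-weighted poly loss, mirroring the final `poly_loss` line. Each lane is scaled by sqrt(total valid points / valid points in the lane); once squared, that factor multiplies the lane's summed error by total / n_lane, so every lane counts equally regardless of how many points it has. The default threshold of 15 / 720 is an assumption for this sketch, and lanes are given as Python lists rather than tensors.

```python
def poly_loss(pred_polys, target_xs, target_ys, threshold=15 / 720):
    """Thresholded, per-lane-weighted squared error between predicted and target xs.

    pred_polys: one (a3, a2, a1, a0) tuple per valid lane (highest degree first)
    target_xs, target_ys: one list per lane, equal lengths; x < 0 marks padding
    """
    valid = [[x >= 0 for x in xs] for xs in target_xs]
    total_valid = sum(sum(mask) for mask in valid)
    num_lanes = len(pred_polys)
    if num_lanes == 0 or total_valid == 0:
        return 0.0
    loss_sum = 0.0
    for (a3, a2, a1, a0), xs, ys, mask in zip(pred_polys, target_xs, target_ys, valid):
        n_lane = sum(mask)
        if n_lane == 0:
            continue  # a lane without valid points contributes nothing
        # squared, this weight scales the lane by total_valid / n_lane
        weight = (total_valid / n_lane) ** 0.5
        for x, y, ok in zip(xs, ys, mask):
            if not ok:
                continue
            pred_x = a3 * y**3 + a2 * y**2 + a1 * y + a0
            sq_err = (weight * pred_x - weight * x) ** 2
            # like nn.Threshold(threshold**2, 0.): keep only values strictly above the cutoff
            if sq_err > threshold**2:
                loss_sum += sq_err
    # same normalisation as the original: number of lanes times number of valid points
    return loss_sum / (num_lanes * total_valid)
```

For instance, a lane with one valid point and a lane with three valid points, each off by the same 0.1 in x, contribute equally to the sum.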