MMSeg——Mutli-view时序数据检查与可视化

可视化功能函数

首先我们需要构造一个能够用于可视化的功能函数,我们将其写在tools/visualize_helper.py下,方便在其他函数中调用:

visualize_helper.py

import os
import cv2
import torch
import numpy as np

# visualize
label_colors = np.array([
        [255, 255, 255],
        [0, 0, 255],
        [0, 255, 0],
        [255, 0, 0],
        [0, 255, 255],
        ])

def decode_segmap(mask):
    rgb_mask_list = [mask.copy() for i in range(3)]
    rgb = np.ones((mask.shape[0], mask.shape[1], 3)) # create an empty rgb image to save clorized masks
    for idx, single_mask in enumerate(rgb_mask_list):
        for idx_c, color in enumerate(label_colors):
            rgb_mask_list[idx][single_mask == idx_c] = color[idx] # colorize pixels if the value is equal to the class num
        rgb[:, :, idx] = rgb_mask_list[idx] # rgb = [r, g, b]
    return rgb.astype(np.int)
    
def post_process(mask):
    '''
    mask: [W, H]
    gt_mask: [W, H, 3]
    '''
    gt_mask = decode_segmap(mask)
    return gt_mask
    
    

def get_multi_view_imgs(multi_view_file_list, dataset_dir):
    '''
    multi_view_imgs (Tensor): input images, shape is [B,6,3,H,W]
    '''
    img_list = []
    for i, img_path in enumerate(multi_view_file_list):
        img = cv2.imread(os.path.join(dataset_dir, img_path))
        img_list.append(img.astype('float32'))

    num_temporal = int(len(img_list)/6)
    
    surr_img_list = []
    for idx in range(num_temporal):
        surr_img_top = cv2.hconcat(img_list[0+idx*6:3+idx*6])    # 水平拼接
        surr_img_btm = cv2.hconcat(img_list[3+idx*6:6+idx*6])    # 水平拼接
        surr_img = cv2.vconcat([surr_img_top, surr_img_btm])
        surr_img_list.append(surr_img)
    
    surr_imgs = cv2.hconcat(surr_img_list)
        
    return surr_imgs, num_temporal
    
    
def get_gt_imgs(gt):
    '''
    gt (Tensor): input gt, shape is [B,H,W]
    '''
    gt = gt.squeeze(0).squeeze(0)
    gt = gt.cpu().numpy()
    gt_img = post_process(gt)
    return gt_img.astype('float32')
    
def save_train_imgs(img_metas, gt, save_name="multi_view_imgs"):
    '''
    img_metas (list[dict]): List of image info dict
    gt (Tensor): input gt, shape is [B,H,W]
    '''
    for idx, img_meta in enumerate(img_metas):
        dataset_dir = "/".join(img_meta['filename'].split("/")[:-3])
        file_list = img_meta['ori_filename']
        surr_img, num_temporal = get_multi_view_imgs(file_list, dataset_dir)
        gt_img = get_gt_imgs(gt)
        
        surr_img = cv2.resize(surr_img ,(1600*num_temporal, 600))
        gt_img = cv2.resize(gt_img, (300, 600))
        
        print(surr_img.shape)
        print(gt_img.shape)
        all_img = cv2.hconcat([surr_img, gt_img])    # 水平拼接
        cv2.imwrite(f"{save_name}_bs{idx}.jpg", all_img)

在训练过程中进行可视化

forward_train函数中加入函数即可实现功能:

	def forward_train(self, img, img_metas, gt_semantic_seg):
		#####################
		# 训练过程中数据检查与可视化
		from tools.visualize_helper import save_train_imgs
        save_train_imgs(img_metas, gt_semantic_seg)
		###################
        x = self.extract_feat(img)

        losses = dict()
        loss_decode = self._decode_head_forward_train(x, img_metas,
                                                      gt_semantic_seg)
        losses.update(loss_decode)
        if self.with_auxiliary_head:
            loss_aux = self._auxiliary_head_forward_train(
                x, img_metas, gt_semantic_seg)
            losses.update(loss_aux)
        return losses

可视化效果

  • 左边为前一时刻的6张环视图像,右边为当前时刻的6张环视图像,最右边为BEV的GT。
    在这里插入图片描述
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值