PVEN车辆重识别Pytorch代码解读--数据加载篇(一)

数据加载部分–base.py

PVEN车辆重识别
论文地址:https://arxiv.org/abs/2004.05021
Pytorch代码地址:https://github.com/silverbulletmdc/PVEN

get_imagedata_info函数

输入:data
return:num_pids(车辆id数目), num_imgs(img数量), num_cams(cam数目)
def get_imagedata_info(data):
    """Summarise a list of metadata records.

    Args:
        data: Sized, re-iterable collection of dicts. Every dict must provide
            the keys ``"id"`` and ``"cam"`` with hashable values.

    Returns:
        A tuple ``(num_pids, num_imgs, num_cams)``: the number of distinct
        ``"id"`` values, the number of records, and the number of distinct
        ``"cam"`` values.

    Raises:
        KeyError: If a record lacks ``"id"`` or ``"cam"``.
    """
    # Sets deduplicate while collecting; the former append loop built the
    # same lists that were immediately rebuilt, so it has been dropped.
    pids = {item["id"] for item in data}
    cams = {item["cam"] for item in data}
    return len(pids), len(data), len(cams)

relabel函数

输入:data
return:data(更新后的data), rawid2label(车辆id对应的序号字典), label2rawid(序号对应的车辆id字典)
def relabel(data):
    """Replace raw ids with contiguous labels and cast cameras to ``int``.

    The distinct raw ``"id"`` values are sorted and numbered ``0..N-1``; each
    returned record carries that number as its ``"id"`` and ``int(cam)`` as
    its ``"cam"``.

    Args:
        data: List of metadata dicts, each providing ``"id"`` and ``"cam"``.

    Returns:
        A tuple ``(data, rawid2label, label2rawid)``: the relabelled copies of
        the records (in input order), the mapping raw id -> label, and the
        inverse mapping label -> raw id.

    Raises:
        KeyError: If a record lacks ``"id"`` or ``"cam"``.
    """
    # Copy every record, not just the outer list: a shallow list copy would
    # let the assignments below overwrite the caller's dicts in place.
    data = [item.copy() for item in data]
    raw_ids = sorted({item["id"] for item in data})
    rawid2label = {raw_id: label for label, raw_id in enumerate(raw_ids)}
    label2rawid = dict(enumerate(raw_ids))
    for item in data:
        item["id"] = rawid2label[item["id"]]
        item["cam"] = int(item["cam"])
    return data, rawid2label, label2rawid

ReIDMetaDataset类

类的属性:
        self.train = metas["train"]#训练集标签
        self.query = metas["query"]#查询集标签
        self.gallery = metas["gallery"]#gallery集合标签
        self.relabel()#标签更新
        self._calc_meta_info()#统计各个集合信息
        if verbose:#用于输出数据集的统计信息
            print("=> Dataset loaded")
            self.print_dataset_statistics()#输出数据集的统计信息

class ReIDMetaDataset:
    """
    Meta information of a ReID dataset, exposed through the ``train``,
    ``query`` and ``gallery`` attributes.

    Each attribute is a list of dict. A dict contains meta information, which is
    {
        "image_path": str, required
        "id": int, required

        "cam"(optional): int,
        "keypoints"(optional): extra information
        "kp_vis"(optional): whether each keypoint is visible
        "mask"(optional): extra information
        "box"(optional): extra information
        "color"(optional): extra information
        "type"(optional): extra information
        "view"(optional): extra information
    }

    NOTE(review): "cam" is listed as optional, but the module-level
    ``relabel`` and ``get_imagedata_info`` helpers used below index
    ``item["cam"]`` unconditionally, so it looks required in practice.
    """

    def __init__(self, pkl_path, verbose=True, **kwargs):
        """Load the pickled splits, relabel them and cache their statistics.

        :param pkl_path: path of a pickle file holding a mapping with the keys
            "train", "query" and "gallery"
        :param verbose: when truthy, print a summary table after loading
        :param kwargs: accepted but not used here
        """
        # NOTE(review): unpickling can execute arbitrary code; only point this
        # at metadata files from a trusted source.
        with open(pkl_path, 'rb') as stream:
            splits = pkl.load(stream)

        for split_name in ("train", "query", "gallery"):
            setattr(self, split_name, splits[split_name])

        self.relabel()
        self._calc_meta_info()

        if not verbose:
            return
        print("=> Dataset loaded")
        self.print_dataset_statistics()

    def relabel(self):
        """Relabel the train split on its own and query+gallery jointly."""
        self.train, self.train_rawid2label, self.train_label2rawid = relabel(self.train)

        # Query and gallery go through one call so both share a single
        # id -> label mapping; the result is then cut back into two lists.
        query_count = len(self.query)
        merged, self.eval_rawid2label, self.eval_label2rawid = relabel(self.query + self.gallery)
        self.query, self.gallery = merged[:query_count], merged[query_count:]

    def print_dataset_statistics(self):
        """Print a table with id / image / camera counts for each split."""
        # Gather every split's counts before printing anything.
        rows = (
            ("  train    | {:5d} | {:8d} | {:9d}", get_imagedata_info(self.train)),
            ("  query    | {:5d} | {:8d} | {:9d}", get_imagedata_info(self.query)),
            ("  gallery  | {:5d} | {:8d} | {:9d}", get_imagedata_info(self.gallery)),
        )
        rule = "  ----------------------------------------"

        print("Dataset statistics:")
        print(rule)
        print("  subset   | # ids | # images | # cameras")
        print(rule)
        for template, counts in rows:
            print(template.format(*counts))
        print(rule)

    def _calc_meta_info(self):
        """Store ``num_<split>_ids/imgs/cams`` attributes for every split."""
        for split_name in ("train", "query", "gallery"):
            n_ids, n_imgs, n_cams = get_imagedata_info(getattr(self, split_name))
            setattr(self, f"num_{split_name}_ids", n_ids)
            setattr(self, f"num_{split_name}_imgs", n_imgs)
            setattr(self, f"num_{split_name}_cams", n_cams)

ReIDDataset类

继承Dataset类,用来加载数据,需要定义__init__,__getitem__和__len__方法
class ReIDDataset(Dataset):
    """Turn a collection of metadata records into an image dataset."""

    def __init__(self, meta_dataset, *, with_mask=False, mask_num=5, transform=None, preprocessing=None):
        """Store the metadata records and the per-sample processing options.

        Arguments:
            meta_dataset -- indexable, sized collection of metadata dicts.
                Each dict needs "image_path", plus "mask_path" when
                with_mask is True.
                NOTE(review): the ReIDMetaDataset class shown in this file
                defines neither __getitem__ nor __len__, so presumably one of
                its train/query/gallery lists is passed here -- confirm
                against the callers.

        Keyword Arguments:
            with_mask {bool} -- whether to load a mask for every sample; it is
                read from the file named by the record's "mask_path"
                (default: {False})
            mask_num {int} -- number of mask channels; channel v marks the
                pixels whose grayscale value equals v (default: {5})
            transform -- optional augmentation, called as transform(**sample);
                its return value replaces the sample (default: {None})
            preprocessing -- optional final step (e.g. normalize / to tensor),
                called the same way as transform (default: {None})
        """
        self.meta_dataset = meta_dataset
        self.transform = transform
        self.preprocessing = preprocessing
        self.with_mask = with_mask
        self.mask_num = mask_num

    def read_mask(self, sample):
        """Load the mask of *sample* and store it under ``sample["mask"]``.

        The image at ``sample["mask_path"]`` is read as grayscale and expanded
        into ``mask_num`` float32 channels stacked on the last axis; channel v
        is 1.0 where the pixel value equals v and 0.0 elsewhere.

        Raises:
            OSError: If the mask image cannot be read.
        """
        mask_path = sample["mask_path"]
        label_map = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports failure by returning None instead of raising;
        # comparing None with the class values below would silently yield a
        # meaningless all-zero 1-D "mask", so fail loudly here.
        if label_map is None:
            raise OSError(f"cannot read mask image: {mask_path}")
        channels = [label_map == v for v in range(self.mask_num)]
        sample["mask"] = np.stack(channels, axis=-1).astype('float32')

    def __getitem__(self, item):
        """Build the sample for index *item*.

        Returns a copy of the metadata record extended with "image" (and
        "mask" when with_mask is set), after it has been passed through
        ``transform`` and ``preprocessing`` when those are set.
        """
        meta: dict = self.meta_dataset[item]
        # Work on a shallow copy so the stored record does not accumulate
        # the loaded image / mask.
        sample = meta.copy()
        sample["image"] = read_rgb_image(meta["image_path"])

        if self.with_mask:
            self.read_mask(sample)

        # Augmentation first, then the final preprocessing step.
        if self.transform:
            sample = self.transform(**sample)

        if self.preprocessing:
            sample = self.preprocessing(**sample)

        return sample

    def __len__(self):
        """Return the number of metadata records."""
        return len(self.meta_dataset)

这就是base.py文件中的主要内容

  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值