ReID基础 | 跨模态reid中常用的固定代码

本文介绍了如何在eval_metrics.py中计算SYSU-MM01的rank、mAP和mINP指标,同时展示了如何保存log到txt文件、预处理SYSU-MM01数据为npy格式,以及处理SYSU-MM01和RegDB的query和gallery图像。
摘要由CSDN通过智能技术生成

1. 计算rank、mAP和mINP

"""eval_metrics.py"""

from __future__ import print_function, absolute_import
import numpy as np


def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=20):
    """Evaluation with sysu metric
    Key: for each query identity, its gallery images from the same camera view are discarded. "Following the original setting in ite dataset"
    """
    # distmat传入的是每个特征向量的乘积的负数,gallery(3804*2048)和query转置(2048*301)的矩阵乘积(3804*301)的负数
    # q_pids表示query的标签,g_pids表示gallery的标签
    # q_camids, g_camids分别表示query和gallery样本的相机标签
    # max_rank,统计gallery可能性最大的前多少个

    num_q, num_g = distmat.shape    # num_q表示query的个数,num_g表示gallery的个数

    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))

    # 沿着行从小到大排序,返回该数值的原来的索引号(因为传入的是原值的负数,所以矩阵乘积原值中最大的那个数值,序号最小为0)
    # 返回原值中,每行数值从大到小的索引值
    indices = np.argsort(distmat, axis=1)
    # 得到每行(每个query)可能性从大到小的预测标签
    pred_label = g_pids[indices]
    # 将每行的预测标签和真实标签比较
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)   # 扩大真实标签的维度,最终将True/False转换数据类型1/0
    # print('matches:', matches, matches.shape)       # matches是3804*301维度,每个元素是1或者0.
    
    # compute cmc curve for each query
    new_all_cmc = []    # 一种新的方法,存储所有query的[000111..]数组
    all_cmc = []        # 原始的方法,存储所有query的[000111...]数组
    all_AP = []         # 存储所有query的AP
    all_INP = []        # 存储所有query的INP
    num_valid_q = 0.    # number of valid query(有效的query个数)

    # 遍历所有的query样本
    for q_idx in range(num_q):
        # 1. get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # 2. 要在不同位置的摄像机之间进行匹配
        # 相机2和相机3在相同的位置,所以相机3的probe图像要跳过相机2的gallery图像
        order = indices[q_idx]  # 找到这个query对应的gallery可能性排序索引
        remove = (q_camid == 3) & (g_camids[order] == 2)    # 同时成立则为True,输出301个True或者False
        # 取反,False->True。得到的keep中,True是可以使用的gallery。
        keep = np.invert(remove)

        # 3. compute cmc curve
        # the cmc calculation is different from standard protocol
        # we follow the protocol of the author's released code

        # 去除重复的预测标签
        new_cmc = pred_label[q_idx][keep]       # 取出这个query行,所有True的预测标签
        new_index = np.unique(new_cmc, return_index=True)[1]  # 将重复的预测标签只保留一个,返回重复标签的第一个索引下标
        new_cmc = [new_cmc[index] for index in sorted(new_index)]

        # new_match从找到正确标签开始全是1,之前全是0
        new_match = (new_cmc == q_pid).astype(np.int32)     # 输出1或者0,1表示与query同ID的预测标签
        new_cmc = new_match.cumsum()        # 依次输出前k个元素累加和(k=1,2...) 0 0 0 1 1 1 ...
        new_all_cmc.append(new_cmc[:max_rank])  # 将该样本的序列添加到所有样本的数组中

        # 原始cmc
        orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
        if not np.any(orig_cmc):
            # this condition is true when query identity does not appear in gallery
            continue
        cmc = orig_cmc.cumsum()         # 0 0 0 0 1 1 1 1 2 2 2 2

        # 4. compute mINP
        # refernece: Deep Learning for Person Re-identification: A Survey and Outlook
        pos_idx = np.where(orig_cmc == 1)               # 找到正确标签对应的索引
        pos_max_idx = np.max(pos_idx)                   # 找到最大的索引
        inp = cmc[pos_max_idx] / (pos_max_idx + 1.0)    # 计算INP
        all_INP.append(inp)

        # 将序列0 0 1 1 2 2  转换为 0 0 1 1 1 1
        cmc[cmc > 1] = 1
        all_cmc.append(cmc[:max_rank])
        num_valid_q += 1.

        # 5. compute average precision(AP)
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = orig_cmc.sum()        # 把该行中所有的1求和,得到正确样本的个数
        tmp_cmc = orig_cmc.cumsum()     # 累加,得到0 0 1 1 2 2...

        # 正确标签,在正样本的位置 / 在所有样本中的位置
        tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]   # i是索引,x是数值,成对取出
        tmp_cmc = np.asarray(tmp_cmc) * orig_cmc    # 只保留正确标签的计算结果
        AP = tmp_cmc.sum() / num_rel        # 计算AP
        all_AP.append(AP)

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
    
    all_cmc = np.asarray(all_cmc).astype(np.float32)            # 转变为浮点数
    # standard CMC
    all_cmc = all_cmc.sum(0) / num_valid_q              # 按列相加,除以query个数
    # new CMC
    new_all_cmc = np.asarray(new_all_cmc).astype(np.float32)
    new_all_cmc = new_all_cmc.sum(0) / num_valid_q

    mAP = np.mean(all_AP)
    mINP = np.mean(all_INP)
    return new_all_cmc, mAP, mINP

    
def eval_regdb(distmat, q_pids, g_pids, max_rank = 20):
    # query和gallery的个数
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))
    # 传入的负值,相当于依次得到从大到小数值的索引
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)   # 转换为int32位数据类型


    # compute cmc curve for each query
    all_cmc = []
    all_AP = []
    all_INP = []
    num_valid_q = 0. # number of valid query
    
    # only two cameras
    q_camids = np.ones(num_q).astype(np.int32)
    g_camids = 2* np.ones(num_g).astype(np.int32)
    
    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # remove gallery samples that have the same pid and camid with query
        order = indices[q_idx]
        remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        keep = np.invert(remove)

        # compute cmc curve
        raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
        if not np.any(raw_cmc):
            # this condition is true when query identity does not appear in gallery
            continue

        cmc = raw_cmc.cumsum()

        # compute mINP
        # refernece: Deep Learning for Person Re-identification: A Survey and Outlook
        pos_idx = np.where(raw_cmc == 1)
        pos_max_idx = np.max(pos_idx)
        inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0)
        all_INP.append(inp)

        cmc[cmc > 1] = 1

        all_cmc.append(cmc[:max_rank])
        num_valid_q += 1.

        # compute average precision
        # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        num_rel = raw_cmc.sum()
        tmp_cmc = raw_cmc.cumsum()
        tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
        tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
        AP = tmp_cmc.sum() / num_rel
        all_AP.append(AP)

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)
    mINP = np.mean(all_INP)
    return all_cmc, mAP, mINP

2. 将屏幕输出的log保存在txt文件中

import os
import sys

class Logger(object):
    """
    Write console output to external text file.控制平台的输出
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
    """
    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            self.file = open(fpath, 'w')		# 覆盖之前的内容

    def __del__(self):
        self.close()

    def __enter__(self):
        pass

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        self.console.close()
        if self.file is not None:
            self.file.close()

# 将保存命令放在最前面
sys.stdout = Logger('out.txt')		# sys.stdout不能修改,Logger括号内的文件名称可以更换

for i in range(200):
    print(i)
print("the best!")

3. SYSU-MM01训练集预处理为npy

注:给训练做辅助

"""pre_process_sysu.py"""

# 1.导入库
import numpy as np
from PIL import Image
import os

# 2.设置路径
data_path = 'SYSU-MM01/'
rgb_cameras = ['cam1', 'cam2', 'cam4', 'cam5']
ir_cameras = ['cam3', 'cam6']

# 3.找到训练集和验证集的ID,并预处理

# 找到划分协议文件
file_path_train = os.path.join(data_path, 'exp/train_id.txt')
file_path_val = os.path.join(data_path, 'exp/val_id.txt')

with open(file_path_train, 'r') as file:
    ids = file.read().splitlines()              # splitlines按照行('\r', '\r\n', \n')分隔,返回一个包含各行作为元素的列表
    ids = [int(y) for y in ids[0].split(',')]   # ids[0]去掉引号,然后按逗号分隔,形成整数型列表
    id_train = ["%04d" % x for x in ids]        # 转变成4位整数,与数据集文件夹的ID形式对应
    # print(id_train)                           # ['0001', '0002', '0004', '0005', '0007'...'533']

with open(file_path_val, 'r') as file:
    ids = file.read().splitlines()
    ids = [int(y) for y in ids[0].split(',')]
    id_val = ["%04d" % x for x in ids]
    # print(id_val)                             # ['0334', '0335', '0336', '0337', '0338'...]

# 合并训练ID和测试ID
id_train.extend(id_val)

# 4.找到训练集ID对应的所有RGB和IR图像
files_rgb = []      # 保存RGB图像路径
files_ir = []       # 保存IR图像路径

for id in sorted(id_train):     # 按着ID从小到大的顺序
    for cam in rgb_cameras:     # 遍历RGB相机
        img_dir = os.path.join(data_path, cam, id)  # 该相机下该ID的文件夹
        if os.path.isdir(img_dir):
            # 形成该ID所有图像的绝对路径
            new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)])    # os.listdir 返回包含的文件名字的列表
            files_rgb.extend(new_files)

    for cam in ir_cameras:
        img_dir = os.path.join(data_path, cam, id)
        if os.path.isdir(img_dir):
            new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)])
            files_ir.extend(new_files)

# 5.重新打标签(将混乱的标签一一对应变为0-394的新标签)
pid_container = set()   # 集合可以保证里面的元素不重复
# 遍历IR图像,得到ID从小到大的集合
for img_path in files_ir:
    pid = int(img_path[-13:-9])     # 取出图像行人ID
    pid_container.add(pid)
# 构建字典
pid2label = {pid: label for label, pid in enumerate(pid_container)}     # pid行人ID,label索引


fix_image_width = 144
fix_image_height = 288

# 6.读出图片,保存为numpy数组
def read_imgs(train_image):
    train_img = []
    train_label = []
    for img_path in train_image:    # 遍历所有的RGB或者IR图像路径
        # img
        img = Image.open(img_path)                                              # 打开图像
        img = img.resize((fix_image_width, fix_image_height), Image.ANTIALIAS)  # 整形
        pix_array = np.array(img)                                               # 转换为numpy数组
        train_img.append(pix_array)                                             # 加到图像列表中

        # label
        pid = int(img_path[-13:-9])     # 取出persion ID
        pid = pid2label[pid]            # 取出ID对应的索引
        train_label.append(pid)         # 把索引加到标签列表中
    return np.array(train_img), np.array(train_label)


# rgb imges
train_img, train_label = read_imgs(files_rgb)
np.save(data_path + 'train_rgb_resized_img.npy', train_img)
np.save(data_path + 'train_rgb_resized_label.npy', train_label)

# ir imges
train_img, train_label = read_imgs(files_ir)
np.save(data_path + 'train_ir_resized_img.npy', train_img)
np.save(data_path + 'train_ir_resized_label.npy', train_label)

4. SYSU-MM01和RegDB生成query和gallery图像

注意:图像返回的都是图像路径,标签是npy格式。给测试做辅助。

""" data_manager.py """

from __future__ import print_function, absolute_import
import os
import numpy as np
import random


# sysu生成query
def process_query_sysu(data_path, mode='all', relabel=False):
    # 两种搜索模式,query都是一样的
    if mode == 'all':
        ir_cameras = ['cam3', 'cam6']
    elif mode == 'indoor':
        ir_cameras = ['cam3', 'cam6']
    
    file_path = os.path.join(data_path, 'exp/test_id.txt')  # 打开测试的txt文件
    files_ir = []       # 保存所有IR图像的路径
    
    # 读出测试的ID
    with open(file_path, 'r') as file:
        ids = file.read().splitlines()
        ids = [int(y) for y in ids[0].split(',')]
        ids = ["%04d" % x for x in ids]     # 化为4位整数,与数据集文件名称对应
    
    # 依次将每个ID在每个相机中的所有图像找出来
    for id in sorted(ids):      # sorted对列表中的元素排序
        for cam in ir_cameras:  # 遍历相机(3、6)
            img_dir = os.path.join(data_path, cam, id)
            if os.path.isdir(img_dir):
                new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
                files_ir.extend(new_files)

    query_img = []  # 保存所有图像的路径
    query_id = []   # 保存所有图像的person ID
    query_cam = []  # 保存所有图像的cam ID

    for img_path in files_ir:
        camid, pid = int(img_path[-15]), int(img_path[-13:-9])  # 在名字中取出相机ID和personID
        query_img.append(img_path)
        query_id.append(pid)
        query_cam.append(camid)
    return query_img, np.array(query_id), np.array(query_cam)


# sysu生成gallery
def process_gallery_sysu(data_path, mode='all', trial=0, relabel=False):
    
    random.seed(trial)      # 控制随机选择gallery
    
    # 两种搜索模式
    if mode == 'all':
        rgb_cameras = ['cam1', 'cam2', 'cam4', 'cam5']
    elif mode == 'indoor':
        rgb_cameras = ['cam1', 'cam2']
    
    file_path = os.path.join(data_path, 'exp/test_id.txt')  # 找到测试txt文件
    files_rgb = []      # 存储RGB图像
    
    # 读出测试的行人ID
    with open(file_path, 'r') as file:
        ids = file.read().splitlines()
        ids = [int(y) for y in ids[0].split(',')]
        ids = ["%04d" % x for x in ids]
    
    # 根据ID找到每个相机下的所有照片,并选择一个作为gallery(single shoot)
    for id in sorted(ids):
        for cam in rgb_cameras:
            img_dir = os.path.join(data_path, cam, id)
            if os.path.isdir(img_dir):
                new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
                files_rgb.append(random.choice(new_files))      # single shoot:一个ID在一个cam中只随机选择一个作为gallery

    gall_img = []
    gall_id = []
    gall_cam = []
    for img_path in files_rgb:
        camid, pid = int(img_path[-15]), int(img_path[-13:-9])
        gall_img.append(img_path)
        gall_id.append(pid)
        gall_cam.append(camid)
    return gall_img, np.array(gall_id), np.array(gall_cam)


# regdb生成测试集(gallery和query)
def process_test_regdb(img_dir, trial=1, modal='visible'):
    # 两种搜索方法(V-IR & IR-V),打开包含样本和标签的txt文件
    if modal == 'visible':
        input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt'    
    elif modal == 'thermal':
        input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt'    
    
    # 分别得到所有的图像和标签
    with open(input_data_path) as f:
        data_file_list = open(input_data_path, 'rt').read().splitlines()        # 按行分割成不同的元素,形成列表
        # Get full list of image and labels
        file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list]  # 图像路径
        file_label = [int(s.split(' ')[1]) for s in data_file_list]             # 图像标签
        
    return file_image, np.array(file_label)

5. 将某几条print内容保存到txt文件

test_log_file = open('hh.txt', "w")			# 新建txt文件,若文件存在则覆盖原本内容
for epoch in range(10):
    print('Test Epoch: {}'.format(epoch), file=test_log_file)
# 注意,这种格式的print,并不在屏幕打印输出
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值