文章目录
1. 计算rank、mAP和mINP
"""eval_metrics.py"""
from __future__ import print_function, absolute_import
import numpy as np
def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=20):
    """Evaluate retrieval results with the SYSU-MM01 metric.

    Key: cameras 2 and 3 are at the same location, so for a query taken by
    camera 3 the gallery images from camera 2 are discarded, following the
    original setting of the dataset.

    Args:
        distmat: array of shape (num_q, num_g). Each row is argsorted in
            ascending order, so the smallest value ranks first (callers
            usually pass the negated query-gallery similarity).
        q_pids: NumPy array with the identity label of every query.
        g_pids: NumPy array with the identity label of every gallery sample.
        q_camids: NumPy array with the camera id of every query.
        g_camids: NumPy array with the camera id of every gallery sample.
        max_rank: number of CMC ranks to report; clipped to num_g.

    Returns:
        Tuple ``(cmc, mAP, mINP)``. ``cmc`` has length ``max_rank`` and is
        computed on the ranking with repeated gallery identities collapsed
        to their first occurrence; ``mAP`` and ``mINP`` are averaged over
        the queries whose identity appears in the kept gallery.

    Raises:
        AssertionError: if no query identity appears in its kept gallery.
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))

    # Gallery indices of every query, best match first.
    indices = np.argsort(distmat, axis=1)
    # Predicted identity at every rank and whether it equals the query's.
    pred_label = g_pids[indices]
    matches = (pred_label == q_pids[:, np.newaxis]).astype(np.int32)

    new_all_cmc = []   # one deduplicated 0/1 CMC row per query
    all_AP = []        # average precision of every valid query
    all_INP = []       # inverse negative penalty of every valid query
    num_valid_q = 0.   # number of valid queries

    for q_idx in range(num_q):
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # Camera-3 probes skip camera-2 gallery images (same location).
        order = indices[q_idx]
        remove = (q_camid == 3) & (g_camids[order] == 2)
        keep = np.invert(remove)

        # The CMC follows the protocol of the dataset authors' released
        # code: it is computed over distinct identities, not over images.
        new_all_cmc.append(_dedup_cmc(pred_label[q_idx][keep], q_pid, max_rank))

        # Binary vector; positions with value 1 are correct matches.
        orig_cmc = matches[q_idx][keep]
        if not np.any(orig_cmc):
            # The query identity does not appear in the kept gallery.
            continue
        num_valid_q += 1.

        AP, inp = _ap_and_inp(orig_cmc)
        all_AP.append(AP)
        all_INP.append(inp)

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    # Rows of invalid queries are all zero, so they add nothing to the sum.
    new_all_cmc = np.asarray(new_all_cmc).astype(np.float32)
    new_all_cmc = new_all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)
    mINP = np.mean(all_INP)
    return new_all_cmc, mAP, mINP


def _dedup_cmc(ranked_pids, q_pid, max_rank):
    """Return the cumulative 0/1 match row over distinct identities, length max_rank."""
    ranked_pids = np.asarray(ranked_pids)
    # Index of the first occurrence of every identity, in ranking order.
    first_seen = np.sort(np.unique(ranked_pids, return_index=True)[1])
    hits = (ranked_pids[first_seen] == q_pid).astype(np.int32)
    row = hits.cumsum()[:max_rank]
    missing = max_rank - row.size
    if missing > 0:
        # Fewer distinct identities than ranks: the curve stays at its last
        # value, which also keeps every row the same length.
        tail = row[-1] if row.size else 0
        row = np.concatenate([row, np.full(missing, tail, dtype=row.dtype)])
    return row


def _ap_and_inp(orig_cmc):
    """Return (average precision, inverse negative penalty) of a non-empty 0/1 match vector."""
    hit_count = orig_cmc.cumsum()  # e.g. 0 0 1 1 2 2
    # mINP reference: Deep Learning for Person Re-identification: A Survey and Outlook
    last_hit = np.flatnonzero(orig_cmc)[-1]
    inp = hit_count[last_hit] / (last_hit + 1.0)
    # AP reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
    # Precision at every rank, kept only at the ranks holding a correct match.
    precision = hit_count / np.arange(1, orig_cmc.size + 1, dtype=np.float64)
    ap = (precision * orig_cmc).sum() / orig_cmc.sum()
    return ap, inp
def eval_regdb(distmat, q_pids, g_pids, max_rank=20):
    """Evaluate retrieval results on RegDB.

    Args:
        distmat: array of shape (num_q, num_g); every row is argsorted in
            ascending order, so the smallest value ranks first.
        q_pids: NumPy array with the identity label of every query.
        g_pids: NumPy array with the identity label of every gallery sample.
        max_rank: number of CMC ranks to report; clipped to num_g.

    Returns:
        Tuple ``(cmc, mAP, mINP)`` averaged over the queries whose identity
        appears in the gallery.

    Raises:
        AssertionError: if no query identity appears in the gallery.
    """
    n_query, n_gallery = distmat.shape
    if max_rank > n_gallery:
        max_rank = n_gallery
        print("Note: number of gallery samples is quite small, got {}".format(n_gallery))

    # Callers pass negated similarities, so ascending order puts the most
    # similar gallery sample first.
    ranking = np.argsort(distmat, axis=1)
    hit_table = (g_pids[ranking] == q_pids[:, np.newaxis]).astype(np.int32)

    # Only two cameras: every query is tagged 1 and every gallery sample 2.
    query_cams = np.ones(n_query).astype(np.int32)
    gallery_cams = 2 * np.ones(n_gallery).astype(np.int32)

    cmc_rows = []
    ap_values = []
    inp_values = []
    valid_queries = 0.  # number of valid queries

    for row in range(n_query):
        order = ranking[row]
        # Drop gallery samples sharing both identity and camera with the
        # query (never the case with the camera tags assigned above).
        same_pid_and_cam = (g_pids[order] == q_pids[row]) & (gallery_cams[order] == query_cams[row])
        hits = hit_table[row][~same_pid_and_cam]  # 1 marks a correct match
        if not hits.any():
            # The query identity does not appear in the gallery.
            continue
        valid_queries += 1.

        running = hits.cumsum()

        # mINP, reference: Deep Learning for Person Re-identification: A Survey and Outlook
        last_hit = np.flatnonzero(hits)[-1]
        inp_values.append(running[last_hit] / (last_hit + 1.0))

        # CMC row: 0 before the first correct match, 1 from there on.
        cmc_rows.append(np.minimum(running, 1)[:max_rank])

        # Average precision, reference:
        # https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
        precision = running / np.arange(1, hits.size + 1, dtype=np.float64)
        ap_values.append((precision * hits).sum() / hits.sum())

    assert valid_queries > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(cmc_rows).astype(np.float32).sum(0) / valid_queries
    return all_cmc, np.mean(ap_values), np.mean(inp_values)
2. 将屏幕输出的log保存在txt文件中
import os
import sys
class Logger(object):
    """
    Write console output to an external text file as well as to the console.
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.

    Assign an instance to ``sys.stdout`` to mirror everything that is printed
    into ``fpath``. With ``fpath=None`` the text only goes to the console.
    """

    def __init__(self, fpath=None):
        # The stream that was sys.stdout when the logger was created.
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            self.file = open(fpath, 'w')  # overwrites previous contents

    def __del__(self):
        self.close()

    def __enter__(self):
        # Return the logger so ``with Logger(path) as log:`` binds it.
        return self

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        """Write *msg* to the console and, if one is open, to the log file."""
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        """Flush the console and force the log file contents to disk."""
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        """Close the log file; safe to call more than once.

        The console stream is left open on purpose: it is the process-wide
        stdout, and closing it would break every later print.
        """
        log_file = getattr(self, 'file', None)
        if log_file is not None:
            log_file.close()
            # Later writes go to the console only instead of failing on a
            # closed file.
            self.file = None
# Install the logger before anything else is printed.
# Keep the assignment target sys.stdout as it is; only the file name passed
# to Logger may be changed.
sys.stdout = Logger('out.txt')
for number in range(200):
    print(number)
print("the best!")
3. SYSU-MM01训练集预处理为npy
注:给训练做辅助
"""pre_process_sysu.py"""
# 1.导入库
import numpy as np
from PIL import Image
import os
# 2. Dataset location and camera split
data_path = 'SYSU-MM01/'
rgb_cameras = ['cam1', 'cam2', 'cam4', 'cam5']
ir_cameras = ['cam3', 'cam6']


def _read_id_list(id_file):
    """Return the ids on the first line of *id_file* as 4-digit strings."""
    with open(id_file, 'r') as handle:
        # The ids are comma separated on the first line of the file.
        first_line = handle.read().splitlines()[0]
    # Zero-pad to 4 digits to match the identity folder names.
    return ["%04d" % int(token) for token in first_line.split(',')]


def _list_images(cameras, person_id):
    """Return the sorted image paths of *person_id* under each of *cameras*."""
    found = []
    for cam in cameras:
        img_dir = os.path.join(data_path, cam, person_id)
        if os.path.isdir(img_dir):
            found.extend(sorted(img_dir + '/' + name for name in os.listdir(img_dir)))
    return found


# 3. Read the training and validation identities from the split files
file_path_train = os.path.join(data_path, 'exp/train_id.txt')
file_path_val = os.path.join(data_path, 'exp/val_id.txt')
id_train = _read_id_list(file_path_train)
id_val = _read_id_list(file_path_val)
# Merge the validation identities into the training identities.
id_train.extend(id_val)

# 4. Collect every RGB and IR image of those identities, in identity order
files_rgb = []  # RGB image paths
files_ir = []   # IR image paths
for person_id in sorted(id_train):
    files_rgb.extend(_list_images(rgb_cameras, person_id))
    files_ir.extend(_list_images(ir_cameras, person_id))

# 5. Relabel: map the sparse person ids onto consecutive labels 0..N-1
pid_container = set()  # a set keeps each person id once
for img_path in files_ir:
    pid = int(img_path[-13:-9])  # the identity folder name inside the path
    pid_container.add(pid)
# Sets have no guaranteed iteration order; sort so the labels follow
# ascending person id on every run.
pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))}

fix_image_width = 144
fix_image_height = 288
# 6.读出图片,保存为numpy数组
def read_imgs(train_image):
    """Load, resize and label every image path in *train_image*.

    Args:
        train_image: iterable of image paths (all RGB or all IR). The
            person id is read from characters ``[-13:-9]`` of each path.

    Returns:
        Tuple ``(images, labels)`` of NumPy arrays: the images resized to
        ``fix_image_width`` x ``fix_image_height`` and their ``pid2label``
        labels, in input order.

    Raises:
        KeyError: if a path's person id is missing from ``pid2label``.
    """
    train_img = []
    train_label = []
    for img_path in train_image:
        # The with-block closes the underlying file once the pixels are read.
        with Image.open(img_path) as img:
            # Image.ANTIALIAS was an alias of LANCZOS and was removed in
            # Pillow 10; LANCZOS is the same filter under its current name.
            resized = img.resize((fix_image_width, fix_image_height), Image.LANCZOS)
            train_img.append(np.array(resized))
        # Person id from the path, mapped to its consecutive label.
        train_label.append(pid2label[int(img_path[-13:-9])])
    return np.array(train_img), np.array(train_label)
# Convert each modality (RGB first, then IR) and store its image and label
# arrays next to the dataset.
for modality_files, img_npy, label_npy in (
    (files_rgb, 'train_rgb_resized_img.npy', 'train_rgb_resized_label.npy'),
    (files_ir, 'train_ir_resized_img.npy', 'train_ir_resized_label.npy'),
):
    train_img, train_label = read_imgs(modality_files)
    np.save(data_path + img_npy, train_img)
    np.save(data_path + label_npy, train_label)
4. SYSU-MM01和RegDB生成query和gallery图像
注意:返回的图像都是图像路径(字符串列表),标签是NumPy数组(np.ndarray)。给测试做辅助。
""" data_manager.py """
from __future__ import print_function, absolute_import
import os
import numpy as np
import random
# sysu生成query
def process_query_sysu(data_path, mode='all', relabel=False):
    """Build the SYSU-MM01 query set from the IR cameras.

    Args:
        data_path: dataset root containing ``exp/test_id.txt`` and the
            ``camN/<id>/`` image folders.
        mode: ``'all'`` or ``'indoor'``; the query set is the same for both.
        relabel: accepted for interface compatibility; not used.

    Returns:
        Tuple ``(query_img, query_id, query_cam)``: the list of image
        paths, and NumPy arrays with the person id and camera id parsed
        from each path.

    Raises:
        ValueError: if *mode* is not ``'all'`` or ``'indoor'``.
    """
    if mode not in ('all', 'indoor'):
        raise ValueError("mode must be 'all' or 'indoor', got {!r}".format(mode))
    # Both search modes query with the same two IR cameras.
    ir_cameras = ['cam3', 'cam6']

    # The test identities are comma separated on the first line of the file.
    file_path = os.path.join(data_path, 'exp/test_id.txt')
    with open(file_path, 'r') as id_file:
        first_line = id_file.read().splitlines()[0]
    # Zero-pad to 4 digits to match the identity folder names.
    ids = ["%04d" % int(token) for token in first_line.split(',')]

    # Every image of every identity under every IR camera.
    files_ir = []
    for person_id in sorted(ids):
        for cam in ir_cameras:
            img_dir = os.path.join(data_path, cam, person_id)
            if os.path.isdir(img_dir):
                files_ir.extend(sorted(img_dir + '/' + name for name in os.listdir(img_dir)))

    query_img = []  # image paths
    query_id = []   # person id of every image
    query_cam = []  # camera id of every image
    for img_path in files_ir:
        # Paths end in "camN/PPPP/xxxx.jpg": N sits 15 characters from the
        # end and PPPP at [-13:-9].
        camid, pid = int(img_path[-15]), int(img_path[-13:-9])
        query_img.append(img_path)
        query_id.append(pid)
        query_cam.append(camid)
    return query_img, np.array(query_id), np.array(query_cam)
# sysu生成gallery
def process_gallery_sysu(data_path, mode='all', trial=0, relabel=False):
    """Build a single-shot SYSU-MM01 gallery from the RGB cameras.

    Args:
        data_path: dataset root containing ``exp/test_id.txt`` and the
            ``camN/<id>/`` image folders.
        mode: ``'all'`` uses cam1, cam2, cam4 and cam5; ``'indoor'`` uses
            cam1 and cam2 only.
        trial: seed for the ``random`` module, making the random gallery
            pick reproducible per trial.
        relabel: accepted for interface compatibility; not used.

    Returns:
        Tuple ``(gall_img, gall_id, gall_cam)``: the list of image paths,
        and NumPy arrays with the person id and camera id parsed from each
        path.

    Raises:
        ValueError: if *mode* is not ``'all'`` or ``'indoor'``.
    """
    random.seed(trial)  # controls which gallery images are drawn

    cameras_by_mode = {
        'all': ['cam1', 'cam2', 'cam4', 'cam5'],
        'indoor': ['cam1', 'cam2'],
    }
    if mode not in cameras_by_mode:
        raise ValueError("mode must be 'all' or 'indoor', got {!r}".format(mode))
    rgb_cameras = cameras_by_mode[mode]

    # The test identities are comma separated on the first line of the file.
    file_path = os.path.join(data_path, 'exp/test_id.txt')
    with open(file_path, 'r') as id_file:
        first_line = id_file.read().splitlines()[0]
    # Zero-pad to 4 digits to match the identity folder names.
    ids = ["%04d" % int(token) for token in first_line.split(',')]

    # Single-shot: one random image per identity per camera.
    files_rgb = []
    for person_id in sorted(ids):
        for cam in rgb_cameras:
            img_dir = os.path.join(data_path, cam, person_id)
            if os.path.isdir(img_dir):
                # Sort first so the seeded choice does not depend on the
                # order in which the file system lists the folder.
                new_files = sorted(img_dir + '/' + name for name in os.listdir(img_dir))
                files_rgb.append(random.choice(new_files))

    gall_img = []
    gall_id = []
    gall_cam = []
    for img_path in files_rgb:
        # Paths end in "camN/PPPP/xxxx.jpg": N sits 15 characters from the
        # end and PPPP at [-13:-9].
        camid, pid = int(img_path[-15]), int(img_path[-13:-9])
        gall_img.append(img_path)
        gall_id.append(pid)
        gall_cam.append(camid)
    return gall_img, np.array(gall_id), np.array(gall_cam)
# regdb生成测试集(gallery和query)
def process_test_regdb(img_dir, trial=1, modal='visible'):
    """Read one RegDB test split (used as gallery or query).

    Args:
        img_dir: dataset root. The index file path is built as
            ``img_dir + 'idx/...'``, so it must end with a path separator.
        trial: number of the split; selects ``idx/test_<modal>_<trial>.txt``.
        modal: ``'visible'`` or ``'thermal'``.

    Returns:
        Tuple ``(file_image, file_label)``: the list of image paths and a
        NumPy array with their integer labels.

    Raises:
        ValueError: if *modal* is not ``'visible'`` or ``'thermal'``.
    """
    if modal not in ('visible', 'thermal'):
        raise ValueError("modal must be 'visible' or 'thermal', got {!r}".format(modal))
    input_data_path = img_dir + 'idx/test_{}_{}'.format(modal, trial) + '.txt'

    # Read the index once; the with-block closes the handle.
    with open(input_data_path, 'rt') as f:
        data_file_list = f.read().splitlines()

    # Each line is "<relative image path> <label>".
    file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list]
    file_label = [int(s.split(' ')[1]) for s in data_file_list]
    return file_image, np.array(file_label)
5. 将某几条print内容保存到txt文件
# Create the txt file (existing contents are overwritten). The with-block
# closes it afterwards, so every line is flushed to disk.
with open('hh.txt', "w") as test_log_file:
    for epoch in range(10):
        # NOTE: print(..., file=...) writes to the file only; nothing is
        # shown on the screen.
        print('Test Epoch: {}'.format(epoch), file=test_log_file)