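I. What is training image weighting
Image weighting addresses class imbalance by sampling training images at different frequencies, according to the class distribution of their labels.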
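1. Read the ground-truth (GT) labels of the training samples and store them in a list.
2. Count the target boxes of each class across the training labels, then give each class a weight equal to the reciprocal of its box count, so the more boxes a class has, the smaller its weight. This yields a per-class weight histogram.
3. During training, weight each image in the training list by the classes it contains and resample the training image indices with random.choices(population, weights=None, *, cum_weights=None, k=1). Images containing rare classes are drawn more often, which balances the samples.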
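As a small illustration of step 3, the sketch below shows how random.choices draws indices with replacement in proportion to the given weights. The three-image population and its weights are made-up values for illustration, not taken from any dataset.

```python
import random
from collections import Counter

# Hypothetical example: three images with relative sampling weights 1 : 1 : 8.
image_ids = [0, 1, 2]
image_weights = [0.1, 0.1, 0.8]

rng = random.Random(0)  # fixed seed so the run is repeatable
# Draw 10,000 indices with replacement; higher-weighted images are drawn more often.
sampled = rng.choices(image_ids, weights=image_weights, k=10_000)
counts = Counter(sampled)
print(counts)

# cum_weights is the running total of weights; both forms describe the same distribution.
rng_a, rng_b = random.Random(1), random.Random(1)
draw_a = rng_a.choices(image_ids, weights=[1, 1, 8], k=5)
draw_b = rng_b.choices(image_ids, cum_weights=[1, 2, 10], k=5)
same = draw_a == draw_b
print(same)
```

Only the ratio between the weights matters, so they do not need to sum to 1.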
II. Code
Computing the class weights:
import numpy as np
import torch


def labels_to_class_weights(labels, nc=80):
    """
    Get class weights (inverse frequency) from training labels.

    Inputs:
        labels -- list of ground-truth label arrays, one per image, rows are [class, x, y, w, h]
        nc -- number of classes
    """
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()

    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    classes = labels[:, 0].astype(int)  # labels = [class xywh]
    # Histogram of class occurrences
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    # Take the reciprocal of the counts
    weights[weights == 0] = 1  # replace empty bins with 1
    weights = 1 / weights  # inverse frequency per class
    # Normalize so the weights sum to 1
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights).float()
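To make the arithmetic concrete, here is a dependency-free sketch of the same inverse-frequency computation on a made-up three-class dataset with 6, 3, and 1 boxes. The function name inverse_frequency_weights and the counts are illustrative assumptions. The sketch mirrors the steps of labels_to_class_weights but is not a replacement for it.

```python
from fractions import Fraction


def inverse_frequency_weights(counts):
    """Return normalized inverse-frequency weights for a list of per-class box counts."""
    # Treat empty classes as having one box, mirroring the "replace empty bins with 1" step.
    safe_counts = [c if c > 0 else 1 for c in counts]
    inverse = [Fraction(1, c) for c in safe_counts]
    total = sum(inverse)
    # Dividing by the total makes the weights sum to exactly 1.
    return [w / total for w in inverse]


counts = [6, 3, 1]  # hypothetical box counts for classes 0, 1, 2
weights = inverse_frequency_weights(counts)
print([str(w) for w in weights])  # ['1/9', '2/9', '2/3']
```

The rarest class ends up with the largest share (2/3). One side effect worth noting: a class with no boxes at all is treated as if it had one box, so it also receives the maximum inverse weight.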
Computing the image weights:
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    """
    Produce image weights based on class_weights and image contents.

    Usage: index = random.choices(range(n), weights=image_weights, k=1)  # weighted image sample
    """
    class_counts = []
    for x in labels:
        class_bin = np.bincount(x[:, 0].astype(int), minlength=nc)  # per-class box counts in this image
        class_counts.append(class_bin)
    class_counts = np.array(class_counts)  # shape (n_images, nc)
    # Multiply the class weights by each image's per-class counts, then sum over classes
    return (class_weights.reshape(1, nc) * class_counts).sum(1)
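The weighted sum at the end of labels_to_image_weights can be checked by hand on a tiny example. The sketch below uses plain Python lists instead of NumPy. The helper name image_weights_from_counts, the class weights, and the four images are all hypothetical values chosen to keep the arithmetic easy.

```python
def image_weights_from_counts(class_weights, per_image_counts):
    """Weight each image by the sum of class_weight * box_count over all classes."""
    return [
        sum(w * c for w, c in zip(class_weights, counts))
        for counts in per_image_counts
    ]


class_weights = [1, 2, 6]  # hypothetical unnormalized weights; class 2 is the rarest
per_image_counts = [
    [3, 0, 0],  # image 0: three boxes of the common class 0
    [1, 1, 0],  # image 1: one box each of classes 0 and 1
    [0, 0, 1],  # image 2: a single box of the rare class 2
    [0, 0, 0],  # image 3: background only, no boxes
]
image_weights = image_weights_from_counts(class_weights, per_image_counts)
print(image_weights)  # [3, 3, 6, 0]
```

Two things stand out. A single rare-class box can outweigh several common-class boxes, and an image with no boxes gets weight 0, so weighted sampling would effectively never draw it.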
Updating the image indices of the dataset so that images with higher weights are more likely to be drawn:
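# epoch ------------------------------------------------------------------
# Requires `import random`. class_weights and maps are assumed to be NumPy arrays of length nc
# (per-class weights and per-class mAP) defined earlier in the training script.
for epoch in range(self.start_epoch, self.epochs):
    # Update image weights: before each epoch, resample dataset.indices so that higher-weighted images are drawn more often
    if self.opt["image_weights"]:
        cw = class_weights * (1 - maps) ** 2 / self.nc  # class weights
        iw = labels_to_image_weights(dataset.labels, nc=self.nc, class_weights=cw)  # image weights
        dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx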
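The effect of the (1 - maps) ** 2 factor is easiest to see in isolation. The following is a minimal sketch, assuming maps holds a per-class mAP in the range 0 to 1. The three classes, their equal starting weights, and the mAP values are invented for illustration.

```python
import random
from collections import Counter

nc = 3
class_weights = [1 / 3, 1 / 3, 1 / 3]  # hypothetical: equal inverse-frequency weights
maps = [0.9, 0.5, 0.1]                 # hypothetical per-class mAP

# Classes the model already detects well (high mAP) are down-weighted quadratically.
cw = [w * (1 - m) ** 2 / nc for w, m in zip(class_weights, maps)]

# One image per class, each with a single box, so image weight equals class weight here.
per_image_counts = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
iw = [sum(w * c for w, c in zip(cw, counts)) for counts in per_image_counts]

n = len(per_image_counts)
rng = random.Random(0)  # fixed seed so the run is repeatable
# As in the training loop, draw n indices for one epoch.
indices = rng.choices(range(n), weights=iw, k=n)
# A much larger draw, used only to make the sampling proportions visible.
histogram = Counter(rng.choices(range(n), weights=iw, k=10_000))
print(indices, histogram)
```

With these numbers the relative weights are 0.01 : 0.25 : 0.81, so the image of the poorly detected class 2 dominates the draw. Because random.choices samples with replacement, some images can appear several times in an epoch while others do not appear at all.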