先确定角度的取值范围,代码有注意事项
@try_except # known issue https://github.com/ultralytics/yolov5/issues/5395
@Timeout(30) # known issue https://github.com/ultralytics/yolov5/issues/5611
def plot_labels(labels, names=(), save_dir=Path(''), img_size=1024):
rboxes = poly2rbox(labels[:, 1:])
labels = np.concatenate((labels[:, :1], rboxes[:, :-1]), axis=1) # [cls xyls]
# 定义子区间的数量
num_intervals = 18
# 计算每个子区间的宽度
interval_width = np.pi / num_intervals
# 创建子区间对应的列表
interval_lists = [[] for _ in range(num_intervals)]
# 对rboxes进行遍历
print(rboxes)
for j, rbox in enumerate(rboxes.tolist()):
"""重点:提取rbox中的θ,判断θ值所在区间,例如(-pi/2,pi/2],(0,179]"""
theta = rbox[-1] # 默认(-pi/2,pi/2]
# 判断θ在哪个子区间内
for i in range(num_intervals):
lower_bound = -np.pi / 2 + i * interval_width
upper_bound = -np.pi / 2 + (i + 1) * interval_width
if lower_bound <= theta < upper_bound:
# 将θ放入对应的子区间列表中
interval_lists[i].append(theta)
break
# 打印每个子区间的列表
for i, interval_list in enumerate(interval_lists):
print(f"Interval {i+1}: {len(interval_list)}")
# 绘制直方图
plt.figure(figsize=(10, 6)) # 设置图形大小
for i, interval_list in enumerate(interval_lists):
plt.bar(f'Interval {i+1}', len(interval_list), alpha=0.7)
# 添加标题和标签
plt.title('角度')
# plt.xlabel('Intervals')
# plt.ylabel('Length of Lists')
# 设置特定位置的刻度文本旋转角度
plt.xticks(rotation='vertical') # 默认所有刻度文本都垂直显示
plt.xticks([0, 4, 8, 12, 16], ['-90', '-45', '0', '45', '90']) # 设置特定位置的刻度文本
# 保存直方图到指定位置,保持刻度文本平行于x轴
plt.tight_layout() # 调整布局,防止标签重叠
plt.savefig('./histogram.png', bbox_inches='tight') # 替换为要保存的位置和文件名,并设置bbox_inches为'tight'
plt.close() # 关闭图形,确保不会显示在屏幕上
# plot dataset labels
LOGGER.info(f"Plotting labels to {save_dir / 'labels_xyls.jpg'}... ")
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, hboxes(xyls)
nc = int(c.max() + 1) # number of classes
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'long_edge', 'short_edge'])
# seaborn correlogram
sn.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
plt.close()
# matplotlib labels
matplotlib.use('svg') # faster
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
ax[0].set_ylabel('instances')
if 0 < len(names) < 30:
ax[0].set_xticks(range(len(names)))
ax[0].set_xticklabels(names, rotation=90, fontsize=10)
else:
ax[0].set_xlabel('classes')
sn.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sn.histplot(x, x='long_edge', y='short_edge', ax=ax[3], bins=50, pmax=0.9)
# rectangles
# labels[:, 1:3] = 0.5 # center
labels[:, 1:3] = 0.5 * img_size # center
# labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000
labels[:, 1:] = xywh2xyxy(labels[:, 1:])
# img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255)
img = Image.fromarray(np.ones((img_size, img_size, 3), dtype=np.uint8) * 255)
for cls, *box in labels[:1000]:
ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
ax[1].imshow(img)
ax[1].axis('off')
for a in [0, 1, 2, 3]:
for s in ['top', 'right', 'left', 'bottom']:
ax[a].spines[s].set_visible(False)
plt.savefig(save_dir / 'labels_xyls.jpg', dpi=200)
matplotlib.use('Agg')
plt.close()
通过YOLOv5 OBB的直方图可视化代码修改,增加了一个角度不平衡的可视化结果,可参照2023年论文的图
UCAS-AOD
(7096, 5)
Interval 1: 627
Interval 2: 370
Interval 3: 259
Interval 4: 279
Interval 5: 292
Interval 6: 330
Interval 7: 351
Interval 8: 376
Interval 9: 560
Interval 10: 641
Interval 11: 422
Interval 12: 440
Interval 13: 305
Interval 14: 274
Interval 15: 306
Interval 16: 284
Interval 17: 449
Interval 18: 531
/hy-tmp/v5obb/utils/plots.py:490: UserWarning: Glyph 35282 (\N{CJK UNIFIED IDEOGRAPH-89D2}) missing from current font.
plt.tight_layout() # 调整布局,防止标签重叠
/hy-tmp/v5obb/utils/plots.py:490: UserWarning: Glyph 24230 (\N{CJK UNIFIED IDEOGRAPH-5EA6}) missing from current font.
plt.tight_layout() # 调整布局,防止标签重叠
/hy-tmp/v5obb/utils/plots.py:491: UserWarning: Glyph 35282 (\N{CJK UNIFIED IDEOGRAPH-89D2}) missing from current font.
plt.savefig('./histogram.png', bbox_inches='tight') # 替换为要保存的位置和文件名,并设置bbox_inches为'tight'
/hy-tmp/v5obb/utils/plots.py:491: UserWarning: Glyph 24230 (\N{CJK UNIFIED IDEOGRAPH-5EA6}) missing from current font.
plt.savefig('./histogram.png', bbox_inches='tight') # 替换为要保存的位置和文件名,并设置bbox_inches为'tight'
Plotting labels to runs/train/exp9/labels_xyls.jpg...