1. 下载数据集
下载好之后将其解压放在工程根目录下
2. 数据集可视化观察
"""
==========================================
@author: Seaton
@Time: 2023/8/15:15:46
@IDE: PyCharm
@Summary:可视化探索西瓜语义分割数据集
==========================================
"""
import os
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
# Visualize one image together with its semantic-segmentation annotation.
# Paths of a single sample image and its annotation mask
img_path = 'Watermelon87_Semantic_Seg_Mask/img_dir/train/04_35-2.jpg'
mask_path = 'Watermelon87_Semantic_Seg_Mask/ann_dir/train/04_35-2.png'

# cv2.imread returns None instead of raising when a file cannot be read,
# so fail early with a message that names the offending path.
img = cv2.imread(img_path)
if img is None:
    raise FileNotFoundError(f'Cannot read image: {img_path}')
mask = cv2.imread(mask_path)
if mask is None:
    raise FileNotFoundError(f'Cannot read mask: {mask_path}')

# Print the shapes of the sample image and of its mask
print(f'原图大小:{img.shape}; mask大小:{mask.shape}')
# Print every distinct value that occurs in the annotation mask
print(np.unique(mask))

# BGR color of each class; the position in the list is used as the class ID
palette = [
    ['background', [127, 127, 127]],
    ['red', [0, 0, 200]],
    ['green', [0, 200, 0]],
    ['white', [144, 238, 144]],
    ['seed-black', [30, 30, 30]],
    ['seed-white', [8, 189, 251]],
]

# Convert to a dict: class ID -> BGR color
palette_dict = {idx: color for idx, (_, color) in enumerate(palette)}
print(palette_dict)

# Keep only the first channel of the mask
mask = mask[:, :, 0]

# Map every integer class ID in the mask to its class color.
# The canvas starts as zeros, so pixels whose value is not a key of
# palette_dict stay black.
viz_mask_bgr = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
for idx, color in palette_dict.items():
    viz_mask_bgr[mask == idx] = color

# Blend the colored annotation with the original image
opacity = 0.5  # weight of the original image; higher values look closer to the original
label_viz = cv2.addWeighted(img, opacity, viz_mask_bgr, 1 - opacity, 0)

# Reverse the channel axis (BGR -> RGB) before handing the image to matplotlib
plt.imshow(label_viz[:, :, ::-1])
plt.show()
# Batch visualization: overlay the annotation on up to n*n training images
# and show them in one grid.
# Directories holding the images and their annotation masks
PATH_IMAGE = 'Watermelon87_Semantic_Seg_Mask/img_dir/train'
PATH_MASKS = 'Watermelon87_Semantic_Seg_Mask/ann_dir/train'

# Grid of n rows by n columns
n = 5
# squeeze=False keeps `axes` two-dimensional even when n == 1,
# so the axes[row, col] indexing below always works.
fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, squeeze=False, figsize=(12, 12))

# os.listdir returns entries in arbitrary order; sorting makes the choice
# of the first n*n files reproducible between runs.
for i, file_name in enumerate(sorted(os.listdir(PATH_IMAGE))[:n ** 2]):
    # Load the image and its annotation
    img_path = os.path.join(PATH_IMAGE, file_name)
    # splitext removes only the final extension, so a file name that
    # contains additional dots still maps to the matching mask name.
    mask_path = os.path.join(PATH_MASKS, os.path.splitext(file_name)[0] + '.png')

    # cv2.imread returns None instead of raising when a file cannot be read
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f'Cannot read image: {img_path}')
    mask = cv2.imread(mask_path)
    if mask is None:
        raise FileNotFoundError(f'Cannot read mask: {mask_path}')
    mask = mask[:, :, 0]

    # Map every integer class ID in the mask to its class color;
    # values that are not keys of palette_dict stay black.
    viz_mask_bgr = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    for idx, color in palette_dict.items():
        viz_mask_bgr[mask == idx] = color

    # Blend the colored annotation with the original image
    label_viz = cv2.addWeighted(img, opacity, viz_mask_bgr, 1 - opacity, 0)

    # Draw into the grid cell, reversing the channel axis (BGR -> RGB)
    axes[i // n, i % n].imshow(label_viz[:, :, ::-1])

# Hide the axes of every cell, including cells left empty when the
# directory holds fewer than n*n files.
for ax in axes.ravel():
    ax.axis('off')

fig.suptitle('Image and Semantic Label', fontsize=30)
plt.tight_layout()
# To save the figure, create the `outputs` directory first and uncomment:
# plt.savefig('outputs/D-2.jpg')
plt.show()
这里我注释掉了 `plt.savefig('outputs/D-2.jpg')` 这一行。如果想把结果图保存到本地,请先在工程根目录下新建一个名为 outputs 的文件夹(与代码中的保存路径一致,`plt.savefig` 不会自动创建目录),然后取消该行的注释
即可。