使用plt进行图像、mask、图像和mask叠加显示
import matplotlib.pyplot as plt
# 随机采样索引
sample_ids = [random.randint(0, len(train_img_paths)) for _ in range(24)]
# 遍历
for sample_id in sample_ids:
# 获取id
data_name = train_df.loc[sample_id]["id"]
# 获取图像和掩膜
img, msk = train_dataset[sample_id]
# 对图像进行维度变化,并*255
img = img.permute((1, 2, 0)).numpy()*255.0
# 转为unint8类型
img = img.astype('uint8')
# 掩膜*255
msk = (msk*255).numpy().astype('uint8')
# print(np.max(msk))
# img= Image.fromarray(msk)
# img.show()
# 显示格局设置
plt.figure(figsize=(9, 4))
# 打印id
print(data_name)
# 不显示轴
plt.axis('off')
# 显示第一张位置
plt.subplot(1,3,1)
plt.imshow(img)
# 显示第二张位置
plt.subplot(1,3,2)
plt.imshow(msk)
# 显示第三张位置
plt.subplot(1,3,3)
# 第一层显示图像
plt.imshow(img, cmap='bone')
# 第二层显示mask
plt.imshow(msk, alpha=0.5)
plt.show()