PyTorch创建分割数据集(数据:图像 + 标签:图像)

转载自 https://blog.csdn.net/Teeyohuang/article/details/82108203

前面一篇写创建数据集的博文--- Pytorch创建自己的数据集1 是介绍的应用于图像分类任务的数据集,即输入为一个图像和它的类别数字标签本篇介绍输入的标签label亦为图像的数据集,并包含一些常用的处理手段。比如做图像语义分割时就会用到这种数据输入方式。

1、数据集简介

以VOC2012数据集为例,图像是RGB3通道的,label是1通道的,(其实label原来是几通道的无所谓,只要读取的时候转化成灰度图就行)。

训练数据:

语义label:

这里我们看到label图片都是黑色的,只有白色的轮廓而已

其实是因为label图片里的像素值取值范围是0 ~ 20,即像素点可能的类别共有21类(对此数据集来说),详情如下:

所以对于灰度值0---20来说,我们肉眼看上去就确实都是黑色的,因为灰度值太低了,而白色的轮廓的灰度值是255!

但是这些边界在计算损失值的时候是不作为有效值的,也就是对于灰度值=255的点是忽略的。

如果想看的话,可以用一些色彩变换,对0--20这每一个数字对应一个色彩,就能看出来了,示例如下

这不是重点,只是给大家看一下方便理解而已

2、文本信息

同样有一个文本来指导我对数据的读取,我的信息如下

这其实就是一个记载了图像ID的文本文档,连后缀都没有,但我们依然可以根据这个去数据集中读取相应的image和label

3、代码示例

这个代码是我自己在利用deeplabV2 跑semantic segmentation 任务时写的一个,也许写的并不优美,但反正是可以用的,

可以做个抛砖引玉的目的,对于才入门的朋友,理解这个思路就可,不必照搬我的代码风格……


 
 
  1. import os
  2. import numpy as np
  3. import random
  4. import matplotlib.pyplot as plt
  5. import collections
  6. import torch
  7. import torchvision
  8. import cv2
  9. from PIL import Image
  10. import torchvision.transforms as transforms
  11. from torch.utils import data
  12. class VOCDataSet(data.Dataset):
  13. def __init__(self, root, list_path, crop_size=(321, 321), mean=(104.008, 116.669, 122.675), mirror=True, scale=True, ignore_label=255):
  14. super(VOCDataSet,self).__init__()
  15. self.root = root
  16. self.list_path = list_path
  17. self.crop_h, self.crop_w = crop_size
  18. self.ignore_label = ignore_label
  19. self.mean = np.asarray(mean, np.float32)
  20. self.is_mirror = mirror
  21. self.is_scale = scale
  22. self.img_ids = [i_id.strip() for i_id in open(list_path)]
  23. self.files = []
  24. for name in self.img_ids:
  25. img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
  26. label_file = os.path.join(self.root, "SegmentationClassAug/%s.png" % name)
  27. self.files.append({
  28. "img": img_file,
  29. "label": label_file,
  30. "name": name
  31. })
  32. def __len__(self):
  33. return len(self.files)
  34. def __getitem__(self, index):
  35. datafiles = self.files[index]
  36. '''load the datas'''
  37. name = datafiles[ "name"]
  38. image = Image.open(datafiles[ "img"]).convert( 'RGB')
  39. label = Image.open(datafiles[ "label"]).convert( 'L')
  40. size_origin = image.size # W * H
  41. '''random scale the images and labels'''
  42. if self.is_scale: #如果我在定义dataset时选择了scale=True,就执行本语句对尺度进行随机变换
  43. ratio = 0.5 + random.randint( 0, 11) // 10.0 #0.5~1.5
  44. out_h, out_w = int(size_origin[ 1]*ratio), int(size_origin[ 0]*ratio)
  45. # (H,W)for Resize
  46. image = transforms.Resize((out_h, out_w), Image.LANCZOS)(image)
  47. label = transforms.Resize((out_h, out_w), Image.NEAREST)(label)
  48. '''pad the inputs if their size is smaller than the crop_size'''
  49. pad_w = max(self.crop_w - out_w, 0)
  50. pad_h = max(self.crop_h - out_h, 0)
  51. img_pad = transforms.Pad( padding=( 0, 0,pad_w,pad_h), fill= 0, padding_mode= 'constant')(image)
  52. label_pad = transforms.Pad( padding=( 0, 0,pad_w,pad_h), fill=self.ignore_label, padding_mode= 'constant')(label)
  53. out_size = img_pad.size
  54. '''random crop the inputs'''
  55. if (self.crop_h != 0 or self.crop_w != 0):
  56. #select a random start-point for croping operation
  57. h_off = random.randint( 0, out_size[ 1] - self.crop_h)
  58. w_off = random.randint( 0, out_size[ 0] - self.crop_w)
  59. #crop the image and the label
  60. image = img_pad.crop((w_off,h_off, w_off+self.crop_w, h_off+self.crop_h))
  61. label = label_pad.crop((w_off,h_off, w_off+self.crop_w, h_off+self.crop_h))
  62. '''mirror operation'''
  63. if self.is_mirror:
  64. if np.random.random() < 0.5:
  65. #0:FLIP_LEFT_RIGHT, 1:FLIP_TOP_BOTTOM, 2:ROTATE_90, 3:ROTATE_180, 4:or ROTATE_270.
  66. image = image.transpose( 0)
  67. label = label.transpose( 0)
  68. '''convert PIL Image to numpy array'''
  69. I = np.asarray(image,np.float32) - self.mean
  70. I = I.transpose(( 2, 0, 1)) #transpose the H*W*C to C*H*W
  71. L = np.asarray(np.array(label), np.int64)
  72. #print(I.shape,L.shape)
  73. return I.copy(), L.copy(), np.array(size_origin), name
  74. #这是一个测试函数,也即我的代码写好后,如果直接python运行当前py文件,就会执行以下代码的内容,以检测我上面的代码是否有问题,这其实就是方便我们调试,而不是每次都去run整个网络再看哪里报错
  75. if __name__ == '__main__':
  76. DATA_DIRECTORY = '/home/teeyo/STA/Data/voc_aug/'
  77. DATA_LIST_PATH = '../dataset/list/val.txt'
  78. Batch_size = 4
  79. MEAN = ( 104.008, 116.669, 122.675)
  80. dst = VOCDataSet(DATA_DIRECTORY,DATA_LIST_PATH, mean=( 0, 0, 0))
  81. # just for test, so the mean is (0,0,0) to show the original images.
  82. # But when we are training a model, the mean should have another value
  83. trainloader = data.DataLoader(dst, batch_size = Batch_size)
  84. plt.ion()
  85. for i, data in enumerate(trainloader):
  86. imgs, labels,_,_= data
  87. if i% 1 == 0:
  88. img = torchvision.utils.make_grid(imgs).numpy()
  89. img = img.astype(np.uint8) #change the dtype from float32 to uint8, because the plt.imshow() need the uint8
  90. img = np.transpose(img, ( 1, 2, 0)) #transpose the Channels*H*W to H*W*Channels
  91. #img = img[:, :, ::-1]
  92. plt.imshow(img)
  93. plt.show()
  94. plt.pause( 0.5)
  95. #input()

我个人觉得我应该注释的地方都有相应的注释,虽然有点长, 因为实现了crop和翻转以及scale等功能,但是大家可以下去慢慢揣摩,理解其中的主要思路,与我前一篇的博文Pytorch创建自己的数据集1做对比,那篇博文相当于是提供了最基本的骨架,而这篇就在骨架上长肉生发而已,有疑问的欢迎评论探讨~~

  • 7
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值