# CamVid Dataset: Introduction and Loading, with Code (PyTorch Version)
Recently I wanted to test the CamVid dataset with the BiSeNet network. It was my first time working with CamVid, so I went looking through other people's blog posts hoping for a shortcut, and after searching around I discovered... well, nothing that really explained it.
So I'd like to share my own understanding with you.
I will also share the BiSeNet code that I got running in this post.
## Dataset Overview
The file layout of the CamVid dataset is shown below.
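A sketch of the layout, inferred from the directory names the loading code below expects:

```
CamVid/
├── train/            # training images, *.png
├── train_labels/     # training labels, *_L.png (RGB color-coded)
├── val/
├── val_labels/
├── test/
├── test_labels/
└── class_dict.csv    # class name -> r, g, b, class_11
```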
The class_dict.csv file assigns each object class in the dataset an RGB color value, and pixels are classified by that color.
The class_11 column marks the classes that actually appear in the dataset and are segmented. (This is my personal understanding; if it's wrong, feel free to leave a comment.)
This is only a brief overview; to learn more, just download the dataset and take a look for yourself.
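For a concrete feel, the first few rows of class_dict.csv look roughly like this (the values shown follow the commonly distributed 32-class CamVid color map; treat them as illustrative, since the exact file depends on where you download the dataset):

```
name,r,g,b,class_11
Animal,64,128,64,0
Archway,192,0,128,0
Bicyclist,0,128,192,1
Bridge,0,128,64,0
Building,128,0,0,1
...
```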
## Walking Through the Loading Code
Time for the code!
The following file should be saved as `CamVid.py`:

```python
import os
import torch
import glob
from torchvision import transforms
#import cv2
from PIL import Image
import pandas as pd
import numpy as np
#from imgaug import augmenters as iaa
#import imgaug as ia
from utils import get_label_info, one_hot_it, RandomCrop, reverse_one_hot, one_hot_it_v11, one_hot_it_v11_dice
import random
def augmentation():
# augment images with spatial transformation: Flip, Affine, Rotation, etc...
# see https://github.com/aleju/imgaug for more details
pass
def augmentation_pixel():
# augment images with pixel intensity transformation: GaussianBlur, Multiply, etc...
pass
class CamVid(torch.utils.data.Dataset):
def __init__(self, image_path, label_path, csv_path, scale, loss='dice', mode='train'):
super().__init__()
self.mode = mode
self.image_list = []
if not isinstance(image_path, list):
image_path = [image_path]
for image_path_ in image_path:
self.image_list.extend(glob.glob(os.path.join(image_path_, '*.png')))
self.image_list.sort()
self.label_list = []
if not isinstance(label_path, list):
label_path = [label_path]
for label_path_ in label_path:
self.label_list.extend(glob.glob(os.path.join(label_path_, '*.png')))
self.label_list.sort()
# self.image_name = [x.split('/')[-1].split('.')[0] for x in self.image_list]
# self.label_list = [os.path.join(label_path, x + '_L.png') for x in self.image_list]
# self.fliplr = iaa.Fliplr(0.5)
self.label_info = get_label_info(csv_path)
# resize
# self.resize_label = transforms.Resize(scale, Image.NEAREST)
# self.resize_img = transforms.Resize(scale, Image.BILINEAR)
# normalization
self.to_tensor = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
# self.crop = transforms.RandomCrop(scale, pad_if_needed=True)
self.image_size = scale
self.scale = [0.5, 1, 1.25, 1.5, 1.75, 2]
self.loss = loss
def __getitem__(self, index):
# load image and crop
seed = random.random()
img = Image.open(self.image_list[index])
        img.show()  # pops up the raw image in an external viewer; comment out for real training
# random crop image
# =====================================
# w,h = img.size
# th, tw = self.scale
# i = random.randint(0, h - th)
# j = random.randint(0, w - tw)
# img = F.crop(img, i, j, th, tw)
# =====================================
# print(self.scale)
scale = random.choice(self.scale)
# print(scale)
scale = (int(self.image_size[0] * scale), int(self.image_size[1] * scale))
# print(scale)
# print(self.image_size)
# print(scale)
# randomly resize image and random crop
# =====================================
if self.mode == 'train':
img = transforms.Resize(scale, Image.BILINEAR)(img)
img = RandomCrop(self.image_size, seed, pad_if_needed=True)(img)
# =====================================
img = np.array(img)
# # load label
label = Image.open(self.label_list[index])
        label.show()  # pops up the raw label; comment out for real training
# # crop the corresponding label
# # =====================================
# # label = F.crop(label, i, j, th, tw)
# # =====================================
#
# # randomly resize label and random crop
# # =====================================
if self.mode == 'train':
label = transforms.Resize(scale, Image.NEAREST)(label)
label = RandomCrop(self.image_size, seed, pad_if_needed=True)(label)
# # =====================================
#
label = np.array(label)
#
# augment image and label
# if self.mode == 'train':
# seq_det = self.fliplr.to_deterministic()
# img = seq_det.augment_image(img)
# label = seq_det.augment_image(label)
#
#
# # image -> [C, H, W]
img = Image.fromarray(img)
img = self.to_tensor(img).float()
if self.loss == 'dice':
# label -> [num_classes, H, W]
label = one_hot_it_v11_dice(label, self.label_info).astype(np.uint8)
label = np.transpose(label, [2, 0, 1]).astype(np.float32)
# label = label.astype(np.float32)
label = torch.from_numpy(label)
return img, label
#
elif self.loss == 'crossentropy':
# label = one_hot_it_v11(label, self.label_info).astype(np.uint8)
label = one_hot_it_v11(label,self.label_info)
# label = label.astype(np.float32)
label = torch.from_numpy(label).long()
return img, label
#
def __len__(self):
return len(self.image_list)
if __name__ == '__main__':
path = os.getcwd()
    # build the paths portably instead of hard-coding Windows-style separators
    train_path = os.path.join(path, "data", "CamVid", "train")
    val_path = os.path.join(path, "data", "CamVid", "val")
    train_labels_path = os.path.join(path, "data", "CamVid", "train_labels")
    val_labels_path = os.path.join(path, "data", "CamVid", "val_labels")
    class_dict_path = os.path.join(path, "data", "CamVid", "class_dict.csv")
# data = CamVid('/path/to/CamVid/train', '/path/to/CamVid/train_labels', '/path/to/CamVid/class_dict.csv', (640, 640))
# data = CamVid(['/data/CamVid/train', '/data/CamVid/val'],
# ['/data/CamVid/train_labels', '/data/CamVid/val_labels'], '/data/CamVid/class_dict.csv',
# (720, 960), loss='crossentropy', mode='val')
data = CamVid([train_path,val_path],
[train_labels_path,val_labels_path],class_dict_path,
(720, 960), loss='dice', mode='val')
data.__getitem__(0)
# from model.build_BiSeNet import BiSeNet
# from utils import reverse_one_hot, get_label_info, colour_code_segmentation, compute_global_accuracy
# print(val_labels_path)
# label_info = get_label_info(class_dict_path)
# print(len(label_info))
# label_info = get_label_info('/data/CamVid/class_dict.csv')
# for i, (img, label) in enumerate(data):
# print(label.size())
# print(torch.max(label))
```

The following code is `utils.py`:

```python
import torch.nn as nn
import torch
from torch.nn import functional as F
from PIL import Image
import numpy as np
import pandas as pd
import random
import numbers
import torchvision
import matplotlib.pyplot as plt
def bool_to_num(num):
    # debugging helper (unused by the loader); equivalent to num.astype(float)
number = np.zeros(num.shape)
height = num.shape[0]
width = num.shape[1]
# channel = num.shape[2]
for h in range(height):
for w in range(width):
# for c in range(channel):
if num[h][w]== True:
number[h][w] = 1
else:
number[h][w] = 0
return number
def poly_lr_scheduler(optimizer, init_lr, iter, lr_decay_iter=1,
max_iter=300, power=0.9):
"""Polynomial decay of learning rate
:param init_lr is base learning rate
:param iter is a current iteration
:param lr_decay_iter how frequently decay occurs, default is 1
:param max_iter is number of maximum iterations
    :param power is a polynomial power
"""
# if iter % lr_decay_iter or iter > max_iter:
# return optimizer
lr = init_lr*(1 - iter/max_iter)**power
optimizer.param_groups[0]['lr'] = lr
return lr
# return lr
def get_label_info(csv_path):
    # return label -> {label_name: [r, g, b, class_11]}
ann = pd.read_csv(csv_path)
label = {}
for iter, row in ann.iterrows():
label_name = row['name']
r = row['r']
g = row['g']
b = row['b']
class_11 = row['class_11']
label[label_name] = [int(r), int(g), int(b), class_11]
return label
def one_hot_it(label, label_info):
# return semantic_map -> [H, W]
semantic_map = np.zeros(label.shape[:-1])
for index, info in enumerate(label_info):
color = label_info[info]
# colour_map = np.full((label.shape[0], label.shape[1], label.shape[2]), colour, dtype=int)
equality = np.equal(label, color)
class_map = np.all(equality, axis=-1)
semantic_map[class_map] = index
# semantic_map.append(class_map)
# semantic_map = np.stack(semantic_map, axis=-1)
return semantic_map
def one_hot_it_v11(label, label_info):
    # return semantic_map -> [H, W]
semantic_map = np.zeros(label.shape[:-1])
# from 0 to 11, and 11 means void
class_index = 0
for index, info in enumerate(label_info):
color = label_info[info][:3]
class_11 = label_info[info][3]
if class_11 == 1:
# colour_map = np.full((label.shape[0], label.shape[1], label.shape[2]), colour, dtype=int)
equality = np.equal(label, color)
# print(label.shape)
# print(color)
# print(equality.shape)
class_map = np.all(equality, axis=-1)
# print(class_index)
# plt.imshow(class_map)
# plt.show()
# print(class_map.shape)
# # semantic_map[class_map] = index
semantic_map[class_map] = class_index
class_index += 1
else:
equality = np.equal(label, color)
class_map = np.all(equality, axis=-1)
semantic_map[class_map] = 11
return semantic_map
def one_hot_it_v11_dice(label, label_info):
# return semantic_map -> [H, W, class_num]
semantic_map = []
void = np.zeros(label.shape[:2])
dis_all = np.zeros(label.shape[:2])
for index, info in enumerate(label_info):
color = label_info[info][:3]
class_11 = label_info[info][3]
if class_11 == 1:
# colour_map = np.full((label.shape[0], label.shape[1], label.shape[2]), colour, dtype=int)
equality = np.equal(label, color)
# number = bool_to_num(equality)
# plt.imshow(number)
# plt.show()
class_map = np.all(equality, axis=-1)
# number = bool_to_num(class_map)
# plt.imshow(class_map)
# plt.show()
# semantic_map[class_map] = index
dis_all[class_map] = 1
semantic_map.append(class_map)
else:
equality = np.equal(label, color)
class_map = np.all(equality, axis=-1)
# plt.imshow(class_map)
# plt.show()
void[class_map] = 1
# plt.imshow(dis_all)
# plt.show()
# plt.imshow(void)
# plt.show()
semantic_map.append(void)
    semantic_map = np.stack(semantic_map, axis=-1).astype(np.float32)  # np.float was removed in NumPy 1.20+
return semantic_map
def reverse_one_hot(image):
"""
Transform a 2D array in one-hot format (depth is num_classes),
to a 2D array with only 1 channel, where each pixel value is
the classified class key.
# Arguments
image: The one-hot format image
# Returns
A 2D array with the same width and height as the input, but
with a depth size of 1, where each pixel value is the classified
class key.
"""
# w = image.shape[0]
# h = image.shape[1]
# x = np.zeros([w,h,1])
# for i in range(0, w):
# for j in range(0, h):
# index, value = max(enumerate(image[i, j, :]), key=operator.itemgetter(1))
# x[i, j] = index
image = image.permute(1, 2, 0)
x = torch.argmax(image, dim=-1)
return x
def colour_code_segmentation(image, label_values):
"""
Given a 1-channel array of class keys, colour code the segmentation results.
# Arguments
image: single channel array where each value represents the class key.
label_values
# Returns
Colour coded image for segmentation visualization
"""
# w = image.shape[0]
# h = image.shape[1]
# x = np.zeros([w,h,3])
# colour_codes = label_values
# for i in range(0, w):
# for j in range(0, h):
# x[i, j, :] = colour_codes[int(image[i, j])]
label_values = [label_values[key][:3] for key in label_values if label_values[key][3] == 1]
label_values.append([0, 0, 0])
colour_codes = np.array(label_values)
x = colour_codes[image.astype(int)]
return x
def compute_global_accuracy(pred, label):
pred = pred.flatten()
label = label.flatten()
total = len(label)
count = 0.0
for i in range(total):
if pred[i] == label[i]:
count = count + 1.0
return float(count) / float(total)
def fast_hist(a, b, n):
'''
a and b are predict and mask respectively
n is the number of classes
'''
k = (a >= 0) & (a < n)
return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)
def per_class_iu(hist):
epsilon = 1e-5
return (np.diag(hist) + epsilon) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + epsilon)
class RandomCrop(object):
"""Crop the given PIL Image at a random location.
Args:
size (sequence or int): Desired output size of the crop. If size is an
int instead of sequence like (h, w), a square crop (size, size) is
made.
padding (int or sequence, optional): Optional padding on each border
of the image. Default is 0, i.e no padding. If a sequence of length
4 is provided, it is used to pad left, top, right, bottom borders
respectively.
pad_if_needed (boolean): It will pad the image if smaller than the
desired size to avoid raising an exception.
"""
def __init__(self, size, seed, padding=0, pad_if_needed=False):
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
self.padding = padding
self.pad_if_needed = pad_if_needed
self.seed = seed
@staticmethod
def get_params(img, output_size, seed):
"""Get parameters for ``crop`` for a random crop.
Args:
img (PIL Image): Image to be cropped.
output_size (tuple): Expected output size of the crop.
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
random.seed(seed)
w, h = img.size
th, tw = output_size
if w == tw and h == th:
return 0, 0, h, w
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
return i, j, th, tw
def __call__(self, img):
"""
Args:
img (PIL Image): Image to be cropped.
Returns:
PIL Image: Cropped image.
"""
if self.padding > 0:
img = torchvision.transforms.functional.pad(img, self.padding)
# pad the width if needed
if self.pad_if_needed and img.size[0] < self.size[1]:
img = torchvision.transforms.functional.pad(img, (int((1 + self.size[1] - img.size[0]) / 2), 0))
# pad the height if needed
if self.pad_if_needed and img.size[1] < self.size[0]:
img = torchvision.transforms.functional.pad(img, (0, int((1 + self.size[0] - img.size[1]) / 2)))
i, j, h, w = self.get_params(img, self.size, self.seed)
return torchvision.transforms.functional.crop(img, i, j, h, w)
def __repr__(self):
return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
def cal_miou(miou_list, csv_path):
    # map per-class IoUs back to the names of the class_11 classes; also return the mean
ann = pd.read_csv(csv_path)
miou_dict = {}
cnt = 0
for iter, row in ann.iterrows():
label_name = row['name']
class_11 = int(row['class_11'])
if class_11 == 1:
miou_dict[label_name] = miou_list[cnt]
cnt += 1
return miou_dict, np.mean(miou_list)
class OHEM_CrossEntroy_Loss(nn.Module):
def __init__(self, threshold, keep_num):
super(OHEM_CrossEntroy_Loss, self).__init__()
self.threshold = threshold
self.keep_num = keep_num
self.loss_function = nn.CrossEntropyLoss(reduction='none')
def forward(self, output, target):
loss = self.loss_function(output, target).view(-1)
loss, loss_index = torch.sort(loss, descending=True)
threshold_in_keep_num = loss[self.keep_num]
if threshold_in_keep_num > self.threshold:
loss = loss[loss>self.threshold]
else:
loss = loss[:self.keep_num]
return torch.mean(loss)
def group_weight(weight_group, module, norm_layer, lr):
group_decay = []
group_no_decay = []
for m in module.modules():
if isinstance(m, nn.Linear):
group_decay.append(m.weight)
if m.bias is not None:
group_no_decay.append(m.bias)
elif isinstance(m, (nn.Conv2d, nn.Conv3d)):
group_decay.append(m.weight)
if m.bias is not None:
group_no_decay.append(m.bias)
elif isinstance(m, norm_layer) or isinstance(m, nn.GroupNorm):
if m.weight is not None:
group_no_decay.append(m.weight)
if m.bias is not None:
group_no_decay.append(m.bias)
assert len(list(module.parameters())) == len(group_decay) + len(
group_no_decay)
weight_group.append(dict(params=group_decay, lr=lr))
weight_group.append(dict(params=group_no_decay, weight_decay=.0, lr=lr))
return weight_group
```

Now let's walk through the code properly.
Sorry... it took me far too long to get to this... my fault.
Let's start with the main block.
It loads the paths to the training set, the validation set, and the class dictionary,
and then creates a CamVid dataset object.
__getitem__() is mainly there to be called by the DataLoader when it fetches data. To show what __getitem__() does to an image and its label, the main block calls data.__getitem__(0); a minimal DataLoader sketch follows.
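A sketch under assumed local paths (with the img.show()/label.show() debug calls commented out so no viewer windows pop up):

```python
from torch.utils.data import DataLoader
from CamVid import CamVid

# hypothetical paths; point these at your own copy of the dataset
data = CamVid('data/CamVid/train', 'data/CamVid/train_labels',
              'data/CamVid/class_dict.csv', (720, 960),
              loss='dice', mode='train')
loader = DataLoader(data, batch_size=2, shuffle=True)

for img, label in loader:
    print(img.shape)    # torch.Size([2, 3, 720, 960])
    print(label.shape)  # torch.Size([2, 12, 720, 960]): 11 classes + void
    break
```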
Next, let's go through the pieces of CamVid one by one.
First, the constructor, __init__().
It stores the paths of the raw images and labels in image_list and label_list (note: the paths, not the images themselves), and label_info holds the contents of the class dictionary. transforms.Compose() chains the conversion to a tensor with normalization using the ImageNet mean and standard deviation.
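Because both lists are filled with glob and then sorted, image i and label i line up by filename. The idiom, sketched with hypothetical paths:

```python
import glob
import os

# the same idiom __init__ uses: gather all PNG paths, then sort so that the
# i-th image and the i-th label refer to the same frame
image_list = sorted(glob.glob(os.path.join('data/CamVid/train', '*.png')))
label_list = sorted(glob.glob(os.path.join('data/CamVid/train_labels', '*.png')))

# a label is named "<frame>_L.png" for the image "<frame>.png", so sorting
# keeps the two lists aligned index by index
assert len(image_list) == len(label_list)
```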
Next is the __getitem__() function.
The two branches worth explaining are:
`if self.loss == 'dice':`
This calls one_hot_it_v11_dice. For every class with class_11 == 1, the label is compared against that class's color, and the matching pixels are recorded in a 2D boolean class_map, which is appended to semantic_map. For the classes with class_11 == 0, the label is compared against each of those colors and all the matches are accumulated in a single 2D array void; once the loop finishes, void is appended to semantic_map as well. The function returns a 3D semantic_map.
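The color comparison itself is worth a toy example (the 2x2 "label image" and the color are made up for illustration):

```python
import numpy as np

# a made-up 2x2 RGB "label image"
label = np.array([[[128, 0, 0], [64, 0, 128]],
                  [[64, 0, 128], [64, 0, 128]]])
car_color = [64, 0, 128]  # illustrative RGB value for one class

# np.equal broadcasts the RGB triple over every pixel; np.all over the last
# axis is True only where all three channels match
class_map = np.all(np.equal(label, car_color), axis=-1)
print(class_map)
# [[False  True]
#  [ True  True]]
```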
`if self.loss == 'crossentropy':`
This calls one_hot_it_v11. For every class with class_11 == 1, the positions in semantic_map that match that class's color are marked with that class's integer id (0 through 10). For the classes with class_11 == 0, the matching positions in semantic_map are set to 11 (void). The function returns a 2D semantic_map.
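To go the other way and turn such an id map back into colors for visualization, colour_code_segmentation() from utils.py can be used, roughly like this (the paths and filename are placeholders; substitute a real label file from your copy of the dataset):

```python
import numpy as np
from PIL import Image
from utils import get_label_info, one_hot_it_v11, colour_code_segmentation

label_info = get_label_info('data/CamVid/class_dict.csv')
rgb_label = np.array(Image.open('data/CamVid/train_labels/0001TP_006690_L.png'))

ids = one_hot_it_v11(rgb_label, label_info)          # [H, W], values 0..11
colored = colour_code_segmentation(ids, label_info)  # [H, W, 3]; void -> black
```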