语义分割数据集:CamVid数据集的创建和使用-pytorch

很多人反映进不去CamVid官网,这里放上处理过的数据集下载链接:

链接:https://pan.baidu.com/s/1Kk_t-EugzyZdJuesDaFHQA?pwd=yumi 
提取码:yumi 


本文主要介绍CamVid数据集的使用方法,其他内容不加赘述。

下载地址:CamVid官网

下载方式

从上图连接中下载,labeled images,共701张图像。

处理数据集

下载完成后,共有两个文件夹,分别存储了图像image和标签label。

code.txt总保存了标签的名字信息。

 需要注意的是!标签文件名比原始的图像文件名多了一个 _P。所以需要先把名字处理成一致。

1.使用以下代码,将label文件改名。

import os,sys

cur_path = r'dataset/camvid/labels/' # 你的数据集路径

labels = os.listdir(cur_path)

for label in labels:
    old_label = str(label)
    new_label = label.replace('_P.png','.png')
    print(old_label, new_label)
    os.rename(os.path.join(cur_path,old_label),os.path.join(cur_path,new_label))
    

2.将数据集按照7:3的比例随机分为2部分,分别为训练集和测试集。将两部分的文件加入对应的文件夹中。

import os
import random
import shutil

# 数据集路径
dataset_path = r'dataset\camvid1'
images_path = r'dataset\camvid1/images'
labels_path   = r'dataset\camvid1/labels'

images_name = os.listdir(images_path)
images_num  = len(images_name)
alpha  = int( images_num  * 0.7 )
print(images_num)

random.shuffle(images_name)
train_list = images_name[0:alpha]
valid_list = images_name[alpha:1]

# 确认分割正确
print('train list: ',len(train_list))
print('valid list: ',len(valid_list))

# 创建train,valid和test的文件夹
train_images_path = os.path.join(dataset_path,'train_images')
train_labels_path  = os.path.join(dataset_path,'train_labels')
if os.path.exists(train_images_path)==False:
    os.mkdir(train_images_path )
if os.path.exists(train_labels_path)==False:
    os.mkdir(train_labels_path)

valid_images_path = os.path.join(dataset_path,'valid_images')
valid_labels_path  = os.path.join(dataset_path,'valid_labels')
if os.path.exists(valid_images_path)==False:
    os.mkdir(valid_images_path )
if os.path.exists(valid_labels_path)==False:
    os.mkdir(valid_labels_path)

# 拷贝影像到指定目录
for image in train_list:
    shutil.copy(os.path.join(images_path,image), os.path.join(train_images_path,image))
    shutil.copy(os.path.join(labels_path,image), os.path.join(train_labels_path,image))

for image in valid_list:
    shutil.copy(os.path.join(images_path,image), os.path.join(valid_images_path,image))
    shutil.copy(os.path.join(labels_path,image), os.path.join(valid_labels_path,image))

得到如下结果

 Dataset和DataLoader

# 导入库
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import os.path as osp
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
#用于做可视化
Cam_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128], [0, 32, 128],[0, 16, 128],[0, 64, 64],[0, 64, 32],
                [0, 64, 16],[64, 64, 128],[0, 32, 16],[32,32,32],[16,16,16],[32,16,128],
                [192,16,16],[32,32,196],[192,32,128], [25,15,125],[32,124,23],[111,222,113],
               ]

#32类
Cam_CLASSES = ['Animal','Archway','Bicyclist','Bridge','Building','Car','CartLuggagePram','Child',
               'Column_Pole','Fence','LaneMkgsDriv','LaneMkgsNonDriv','Misc_Text','MotorcycleScooter',
               'OtherMoving','ParkingBlock','Pedestrian','Road','RoadShoulder','Sidewalk','SignSymbol',
               'Sky', 'SUVPickupTruck','TrafficCone','TrafficLight', 'Train','Tree','Truck_Bus', 'Tunnel',
               'VegetationMisc', 'Void','Wall']


torch.manual_seed(17)
# 自定义数据集CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. noralization, shape manipulation, etc.)
    """
    
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(448, 448),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ]) 
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

    
    def __getitem__(self, i):
        # read data
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
        image = self.transform(image=image,mask=mask)
        
        return image['image'], image['mask'][:,:,0]
        
    def __len__(self):
        return len(self.ids)
    
    
# 设置数据集路径
DATA_DIR = r'dataset\camvid' # 根据自己的路径来设置
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
    

train_dataset = CamVidDataset(
    x_train_dir, 
    y_train_dir, 
)
val_dataset = CamVidDataset(
    x_valid_dir, 
    y_valid_dir, 
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

 可以用下面的代码看一下结果

for index, (img, label) in enumerate(train_loader):
    print(img.shape)
    print(label.shape)
    
    plt.figure(figsize=(10,10))
    plt.subplot(221)
    plt.imshow((img[0,:,:,:].moveaxis(0,2)))
    plt.subplot(222)
    plt.imshow(label[0,:,:])
    
    plt.subplot(223)
    plt.imshow((img[6,:,:,:].moveaxis(0,2)))
    plt.subplot(224)
    plt.imshow(label[6,:,:])
    
    plt.show() 
    if index==0:
        break

其他

如果打不开网站或觉得处理数据比较麻烦,可以去kaggle下载已经处理好的Camvid数据集。链接如下:

Camvid dataset | Kaggle

 需要注意的是,这个数据集划分为训练集、验证集和测试集。需要自行修改相关代码使用。

  • 19
    点赞
  • 130
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 19
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yumaomi

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值