Pytorch数据集自定义读取

以读取VOC2012语义分割数据集为例,具体见代码注释:

VocDataset.py

from PIL import Image
import torch
import torch.utils.data as data
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import time

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]

#颜色标签空间转到序号标签空间,就他妈这里浪费巨量的时间,这里还他妈的有问题
def voc_label_indices(colormap, colormap2label):
    """Assign label indices for Pascal VOC2012 Dataset."""
    idx = ((colormap[:, :, 2] * 256 + colormap[ :, :,1]) * 256+ colormap[:, :,0])
    #out = np.empty(idx.shape, dtype = np.int64) 
    out = colormap2label[idx]
    out=out.astype(np.int64)#数据类型转换
    end = time.time()
    return out

class MyDataset(data.Dataset):#创建自定义的数据读取类
    def __init__(self, root, is_train, crop_size=(320,480)):
        self.rgb_mean =(0.485, 0.456, 0.406)
        self.rgb_std = (0.229, 0.224, 0.225)
        self.root=root
        self.crop_size=crop_size
        images = []#创建空列表存文件名称
        txt_fname = '%s/ImageSets/Segmentation/%s' % (root, 'train.txt' if is_train else 'val.txt')
        with open(txt_fname, 'r') as f:
            self.images = f.read().split()
        #数据名称整理
        self.files = []
        for name in self.images:
            img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
            label_file = os.path.join(self.root, "SegmentationClass/%s.png" % name)
            self.files.append({
                "img": img_file,
                "label": label_file,
                "name": name
            })
        self.colormap2label = np.zeros(256**3)
        #整个循环的意思就是将颜色标签映射为单通道的数组索引
        for i, cm in enumerate(VOC_COLORMAP):
            self.colormap2label[(cm[2] * 256 + cm[1]) * 256 + cm[0]] = i
    #按照索引读取每个元素的具体内容
    def __getitem__(self, index):
        
        datafiles = self.files[index]
        name = datafiles["name"]
        image = Image.open(datafiles["img"])
        label = Image.open(datafiles["label"]).convert('RGB')#打开的是PNG格式的图片要转到rgb的格式下,不然结果会比较要命
        #以图像中心为中心截取固定大小图像,小于固定大小的图像则自动填0
        imgCenterCrop = transforms.Compose([
             transforms.CenterCrop(self.crop_size),
             transforms.ToTensor(),
             transforms.Normalize(self.rgb_mean, self.rgb_std),#图像数据正则化
         ])
        labelCenterCrop = transforms.CenterCrop(self.crop_size)
        cropImage=imgCenterCrop(image)
        croplabel=labelCenterCrop(label)
        croplabel=torch.from_numpy(np.array(croplabel)).long()#把标签数据类型转为torch
       
        #将颜色标签图转为序号标签图
        mylabel=voc_label_indices(croplabel, self.colormap2label)
       
        return cropImage,mylabel
    #返回图像数据长度
    def __len__(self):
        return len(self.files)

Train.py

import matplotlib.pyplot as plt
import torch.utils.data as data
import torchvision.transforms as transforms
import numpy as np

from PIL import Image
from VocDataset import MyDataset

#VOC数据集分类对应颜色标签
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]

root='../data/VOCdevkit/VOC2012'
train_data=MyDataset(root,True)
trainloader = data.DataLoader(train_data, 4)

#从数据集中拿出一个批次的数据
for i, data in enumerate(trainloader):
    getimgs, labels= data
    img = transforms.ToPILImage()(getimgs[0])

    labels = labels.numpy()#tensor转numpy
    labels=labels[0]#获得批次标签集中的一张标签图像
    labels = labels.transpose((1,0))#数组维度切换,将第1维换到第0维,第0维换到第1维

    ##将单通道索引标签图片映射回颜色标签图片
    newIm= Image.new('RGB', (480, 320))#创建一张与标签大小相同的图片,用以显示标签所对应的颜色
    for i in range(0, 480):
        for j in range(0, 320):
            sele=labels[i][j]#取得坐标点对应像素的值
            newIm.putpixel((i, j), (int(VOC_COLORMAP[sele][0]), int(VOC_COLORMAP[sele][1]), int(VOC_COLORMAP[sele][2])))

    #显示图像和标签
    plt.figure("image")
    ax1 = plt.subplot(1,2,1)
    ax2 = plt.subplot(1,2,2)
    plt.sca(ax1)
    plt.imshow(img)
    plt.sca(ax2)
    plt.imshow(newIm)
    plt.show()
  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
### 回答1: 在 PyTorch读取自定义数据集的一般步骤如下: 1. 定义数据集类:首先需要定义一个数据集类,继承自 `torch.utils.data.Dataset` 类,并实现 `__getitem__` 和 `__len__` 方法。在 `__getitem__` 方法中,根据索引返回一个样本的数据和标签。 2. 加载数据集:使用 `torch.utils.data.DataLoader` 类加载数据集,可以设置批量大小、多线程读取数据等参数。 下面是一个简单的示例代码,演示如何使用 PyTorch 读取自定义数据集: ```python import torch from torch.utils.data import Dataset, DataLoader class CustomDataset(Dataset): def __init__(self, data, targets): self.data = data self.targets = targets def __getitem__(self, index): x = self.data[index] y = self.targets[index] return x, y def __len__(self): return len(self.data) # 加载训练集和测试集 train_data = ... train_targets = ... train_dataset = CustomDataset(train_data, train_targets) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_data = ... test_targets = ... test_dataset = CustomDataset(test_data, test_targets) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # 训练模型 for epoch in range(num_epochs): for batch_idx, (data, targets) in enumerate(train_loader): # 前向传播、反向传播,更新参数 ... ``` 在上面的示例代码中,我们定义了一个 `CustomDataset` 类,加载了训练集和测试集,并使用 `DataLoader` 类分别对它们进行批量读取。在训练模型时,我们可以像使用 PyTorch 自带的数据集一样,循环遍历每个批次的数据和标签,进行前向传播、反向传播等操作。 ### 回答2: PyTorch是一个开源的深度学习框架,它提供了丰富的功能用于读取和处理自定义数据集。下面是一个简单的步骤来读取自定义数据集。 首先,我们需要定义一个自定义数据集类,该类应继承自`torch.utils.data.Dataset`类,并实现`__len__`和`__getitem__`方法。`__len__`方法应返回数据集的样本数量,`__getitem__`方法根据给定索引返回一个样本。 ```python import torch from torch.utils.data import Dataset class CustomDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] return torch.tensor(sample) ``` 接下来,我们可以创建一个数据集实例并传入自定义数据。假设我们有一个包含多个样本的列表 `data`。 ```python data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] dataset = CustomDataset(data) ``` 然后,我们可以使用`torch.utils.data.DataLoader`类加载数据集,并指定批次大小、是否打乱数据等。 ```python batch_size = 2 dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) ``` 现在,我们可以迭代数据加载器来获取批次的样本。 ```python for batch in dataloader: print(batch) ``` 上面的代码将打印出两个批次的样本。如果`shuffle`参数设置为`True`,则每个批次的样本将是随机的。 总而言之,PyTorch提供了简单而强大的工具来读取和处理自定义数据集,可以根据实际情况进行适当修改和扩展。 ### 回答3: PyTorch是一个流行的深度学习框架,可以用来训练神经网络模型。要使用PyTorch读取自定义数据集,可以按照以下几个步骤进行: 1. 准备数据集:将自定义数据集组织成合适的目录结构。通常情况下,可以将数据集分为训练集、验证集和测试集,每个集合分别放在不同的文件夹中。确保每个文件夹中的数据按照类别进行分类,以便后续的标签处理。 2. 创建数据加载器:在PyTorch中,数据加载器是一个有助于有效读取和处理数据的类。可以使用`torchvision.datasets.ImageFolder`类创建一个数据加载器对象,通过传入数据集的目录路径来实现。 3. 数据预处理:在将数据传入模型之前,可能需要对数据进行一些预处理操作,例如图像变换、标准化或归一化等。可以使用`torchvision.transforms`中的类来实现这些预处理操作,然后将它们传入数据加载器中。 4. 创建数据迭代器:数据迭代器是连接数据集和模型的重要接口,它提供了一个逐批次加载数据的功能。可以使用`torch.utils.data.DataLoader`类创建数据迭代器对象,并设置一些参数,例如批量大小、是否打乱数据等。 5. 使用数据迭代器:在训练时,可以使用Python的迭代器来遍历数据集并加载数据。通常,它会在每个迭代步骤中返回一个批次的数据和标签。可以通过`for`循环来遍历数据迭代器,并在每个步骤中处理批次数据和标签。 这样,我们就可以在PyTorch中成功读取并处理自定义数据集。通过这种方式,我们可以更好地利用PyTorch的功能来训练和评估自己的深度学习模型。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值