pytorch图像处理:读取数据集Dataset和ImageFolder

 1、重写Dataset类:

#源码
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
 
#这个函数就是根据索引,迭代的读取路径和标签。因此我们需要有一个路径和标签的 ‘容器’供我们读
def __getitem__(self, index):
	raise NotImplementedError
 
#返回数据的长度
def __len__(self):
	raise NotImplementedError
def __add__(self, other):
	return ConcatDataset([self, other])

想制作自己的图像数据集供DataLoader拿取,首先要重写Datasets类,主要用来完成从哪里读取数据和标签的功能。主要是__getitem()__(返回数据集和标签)和__len__(返回数据的长度)这两个方法。

完成Datasets类的这两个主要功能后,训练的时候可以把数据集传送给DataLoader就可以获取自己想要的batch数据。

例1:通过包含 数据路径 和 标签 的TXT文件读取

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
 
#集成Dataset类
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
    """
    tex_path : txt文本路径,该文本包含了图像的路径信息,以及标签信息
    transform:数据处理,对图像进行随机剪裁,以及转换成tensor
    """
	fh = open(txt_path, 'r')  #读取文件
	imgs = []  #用来存储路径与标签
    #一行一行的读取
	for line in fh:
		line = line.rstrip()  #这一行就是图像的路径,以及标签  
        
		words = line.split()
		imgs.append((words[0], int(words[1])))  #路径和标签添加到列表中
		self.imgs = imgs                        
		self.transform = transform
		self.target_transform = target_transform
 
def __getitem__(self, index):
	fn, label = self.imgs[index]   #通过index索引返回一个图像路径fn 与 标签label
	img = Image.open(fn).convert('RGB')  #把图像转成RGB
	if self.transform is not None:
		img = self.transform(img) 
	return img, label              #这就返回一个样本
 
def __len__(self):
	return len(self.imgs)          #返回长度,index就会自动的指导读取多少
 

例2:通过标签文件读取

#首先集成Dataset这个类
class DealDataset(Dataset):
    """
        下载数据、初始化数据,都可以在这里完成
    """
    def __init__(self):
 
        #这里xy 就是一个容器,通过读取一个包含有数据和标签信息的文件
        xy = np.loadtxt('../dataSet/diabetes.csv.gz', delimiter=',', dtype=np.float32)
 
        self.x_data = torch.from_numpy(xy[:, 0:-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])
        
        #长度,可以给__len__返回用。
        self.len = xy.shape[0]
    
    def __getitem__(self, index):
        
        #通过索引index,索引到指定的数据以及对应的标签
        return self.x_data[index], self.y_data[index]
 
    def __len__(self):
        return self.len
 

例3:没有标签文件,代码根据文件夹分类自己构造

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform
 
    def __getitem__(self, index):
     
        path_img, label = self.data_info[index]       #索引读取图像路径和标签
        img = Image.open(path_img).convert('RGB')     # 读取图像,返回Image 类型 0~255
 
        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,把图像转为tensor等等
 
        return img, label
 
    def __len__(self):
        return len(self.data_info)
 
    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
 
                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))
 
        return data_info  ##返回的也就是图像路径 和 标签

2、文件夹读取ImageFolder

# 预处理 转为tensor 以及 标准化
transform = transform.Compose([transform.ToTensor(), transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
 
#使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹
traindata = torchvision.datasets.ImageFolder('data/rmb_split/train/', transform=transform)
trainloader = torch.utils.data.DataLoader(traindata, batch_size=4, shuffle=True, num_workers=1)
 
 
testset = torchvision.datasets.ImageFolder('data/rmb_split/test/', transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=True, num_workers=1)
 

 

### 回答1: 在 PyTorch读取自定义数据集的一般步骤如下: 1. 定义数据集类:首先需要定义一个数据集类,继承自 `torch.utils.data.Dataset` 类,并实现 `__getitem__` 和 `__len__` 方法。在 `__getitem__` 方法中,根据索引返回一个样本的数据和标签。 2. 加载数据集:使用 `torch.utils.data.DataLoader` 类加载数据集,可以设置批量大小、多线程读取数据等参数。 下面是一个简单的示例代码,演示如何使用 PyTorch 读取自定义数据集: ```python import torch from torch.utils.data import Dataset, DataLoader class CustomDataset(Dataset): def __init__(self, data, targets): self.data = data self.targets = targets def __getitem__(self, index): x = self.data[index] y = self.targets[index] return x, y def __len__(self): return len(self.data) # 加载训练集和测试集 train_data = ... train_targets = ... train_dataset = CustomDataset(train_data, train_targets) train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) test_data = ... test_targets = ... test_dataset = CustomDataset(test_data, test_targets) test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False) # 训练模型 for epoch in range(num_epochs): for batch_idx, (data, targets) in enumerate(train_loader): # 前向传播、反向传播,更新参数 ... ``` 在上面的示例代码中,我们定义了一个 `CustomDataset` 类,加载了训练集和测试集,并使用 `DataLoader` 类分别对它们进行批量读取。在训练模型时,我们可以像使用 PyTorch 自带的数据集一样,循环遍历每个批次的数据和标签,进行前向传播、反向传播等操作。 ### 回答2: PyTorch是一个开源的深度学习框架,它提供了丰富的功能用于读取和处理自定义数据集。下面是一个简单的步骤来读取自定义数据集。 首先,我们需要定义一个自定义数据集类,该类应继承自`torch.utils.data.Dataset`类,并实现`__len__`和`__getitem__`方法。`__len__`方法应返回数据集的样本数量,`__getitem__`方法根据给定索引返回一个样本。 ```python import torch from torch.utils.data import Dataset class CustomDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): sample = self.data[idx] return torch.tensor(sample) ``` 接下来,我们可以创建一个数据集实例并传入自定义数据。假设我们有一个包含多个样本的列表 `data`。 ```python data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] dataset = CustomDataset(data) ``` 然后,我们可以使用`torch.utils.data.DataLoader`类加载数据集,并指定批次大小、是否打乱数据等。 ```python batch_size = 2 dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) ``` 现在,我们可以迭代数据加载器来获取批次的样本。 ```python for batch in dataloader: print(batch) ``` 上面的代码将打印出两个批次的样本。如果`shuffle`参数设置为`True`,则每个批次的样本将是随机的。 总而言之,PyTorch提供了简单而强大的工具来读取和处理自定义数据集,可以根据实际情况进行适当修改和扩展。 ### 回答3: PyTorch是一个流行的深度学习框架,可以用来训练神经网络模型。要使用PyTorch读取自定义数据集,可以按照以下几个步骤进行: 1. 准备数据集:将自定义数据集组织成合适的目录结构。通常情况下,可以将数据集分为训练集、验证集和测试集,每个集合分别放在不同的文件夹中。确保每个文件夹中的数据按照类别进行分类,以便后续的标签处理。 2. 创建数据加载器:在PyTorch中,数据加载器是一个有助于有效读取和处理数据的类。可以使用`torchvision.datasets.ImageFolder`类创建一个数据加载器对象,通过传入数据集的目录路径来实现。 3. 数据预处理:在将数据传入模型之前,可能需要对数据进行一些预处理操作,例如图像变换、标准化或归一化等。可以使用`torchvision.transforms`中的类来实现这些预处理操作,然后将它们传入数据加载器中。 4. 创建数据迭代器:数据迭代器是连接数据集和模型的重要接口,它提供了一个逐批次加载数据的功能。可以使用`torch.utils.data.DataLoader`类创建数据迭代器对象,并设置一些参数,例如批量大小、是否打乱数据等。 5. 使用数据迭代器:在训练时,可以使用Python的迭代器来遍历数据集并加载数据。通常,它会在每个迭代步骤中返回一个批次的数据和标签。可以通过`for`循环来遍历数据迭代器,并在每个步骤中处理批次数据和标签。 这样,我们就可以在PyTorch中成功读取并处理自定义数据集。通过这种方式,我们可以更好地利用PyTorch的功能来训练和评估自己的深度学习模型。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值