【pytorch】定义自己的dataloader

在使用自己数据集训练网络时,往往需要定义自己的dataloader。这里用最简单的例子做个记录。

定义datalaoder

一般将dataloader封装为一个类,这个类继承自torch.utils.data.dataset

from torch.utils.data import dataset

class LoadData(dataset.Dataset):  # 注意父类的名称,不能写dataset
    pass

需要注意的是dataset是模块名,而Dataset是类名,在python中模块名和类名是完全独立的命名空间,因此这里的父类需要写成dataset.Dataset


在我们定义的LoadData中,至少需要有三个方法:

  • __init__方法,主要用来定义数据的预处理
  • __getitem__方法,返回数据的item和label
  • __len__方法,返回数据个数
from torch.utils.data import dataset

class LoadData(dataset.Dataset):
    
    def __init__(self):
        super(LoadData, self).__init__()
        pass
    
    def __getitem__(self):
        pass
    
    def __len__(self):
        pass

__init__方法需要传入至少两个参数:

  • 一般数据的地址和标签已经被保存在某个文档中了(这里是txt格式的文档)。因此需要传入这个文档的地址。
  • 因为__init__方法要做预处理,一般用来train的预处理和test的预处理是不同的,因此需要区分二者的参数。
	def __init__(self, txt_path, train=True):
        super(LoadData, self).__init__()
        self.img_info = self.get_img(txt_path)
        self.train = train

		# train预处理
        self.train_transforms = transforms.Compose([
            transforms.Resize(20),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

		# test预处理
        self.test_transforms = transforms.Compose([
            transforms.Resize(20),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

	# 这个函数是用来读txt文档的
    def get_img(self, txt_path):
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info))
            return imgs_info

__getitem__方法只需要根据index返回数据的item和label。

    def __getitem__(self, index):
        img_path, label = self.img_info[index]
        img = Image.open(img_path)
        label = int(label)
        
        # 注意区分预处理
        if self.train:
            img = self.train_transforms(img)
        else:
            img = self.test_transforms(img)

        return img, label

__len__方法最简单,仅返回数据项个数。

    def __len__(self):
        return len(self.img_info)

调用dataloader

以训练数据为例,调用dataloader需要两步:

  • 将自定义的LoadData实例化
  • 传入torch.utils.data.dataloader
from torch.utils.data import dataloader

train_dataset = LoadData(txt_path='XXXX', train=True)

train_loader = dataloader.Dataloader(
	dataset=train_dataset,
    batch_size=8,
    shuffle=True
	)

至此,一个最简单的dataloader就完成了!

可以用以下代码测试:

for image, label in train_loader:
	print(image.shape)
    print(label)
### 回答1: 可以使用PyTorch的torch.utils.data.DataLoader类查看dataloader类的相关信息,具体方法是通过在Python环境中导入dataloader类并调用dataloader.info()函数来查看dataloader类的信息。 ### 回答2: 要查看PyTorchDataLoader类的信息,可以通过以下步骤实现: 1. 首先,导入所需的PyTorch库: import torch from torch.utils.data import DataLoader 2. 创建你的数据集对象,例如一个自定义Dataset类: class MyDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] 3. 使用上一步创建的数据集对象来初始化DataLoader对象: data = [1, 2, 3, 4, 5] dataset = MyDataset(data) dataloader = DataLoader(dataset, batch_size=2, shuffle=True) 在这个例子中,我们定义了一个包含5个元素的数据列表,然后创建了一个自定义的数据集对象并将其传递给DataLoader构造函数。 4. 现在,你可以打印DataLoader对象的一些属性信息,如下所示: print("Batch size:", dataloader.batch_size) print("Shuffle:", dataloader.shuffle) print("Number of workers:", dataloader.num_workers) print("Total batches:", len(dataloader)) 这将打印出DataLoader对象的批量大小、是否进行洗牌、工作线程数量以及总批次数等信息。 5. 此外,你还可以迭代DataLoader对象来访问批次数据,例如: for batch in dataloader: print(batch) 这将迭代生成数据集的批次,你可以在每个批次中进行进一步的处理。 总之,通过创建你的自定义数据集对象并传递给DataLoader构造函数,你可以获取DataLoader对象的相关信息,如批量大小、是否洗牌和工作线程数量等。此外,你还可以迭代DataLoader对象以访问数据集的批次数据。 ### 回答3: 在PyTorch中,可以使用`DataLoader`类来加载数据。要查看`DataLoader`类的信息,可以通过以下步骤进行: 首先,导入所需的库: ```python import torch from torch.utils.data import DataLoader ``` 接下来,创建自定义的数据集类和数据加载器: ```python class CustomDataset(torch.utils.data.Dataset): def __init__(self): # 初始化数据集 pass def __len__(self): # 返回数据集大小 pass def __getitem__(self, index): # 返回指定索引处的数据 pass dataset = CustomDataset() dataloader = DataLoader(dataset, batch_size=32, shuffle=True) ``` 在以上代码中,我们首先创建了一个自定义的数据集类`CustomDataset`,并实现了`__len__`和`__getitem__`方法来获取数据集大小和指定索引处的数据。 然后,我们使用`DataLoader`类将数据集加载到数据加载器`dataloader`中。在`DataLoader`类的构造函数中传入`dataset`对象,并指定每个批次的大小为32,并设置`shuffle=True`来打乱数据顺序。 要查看`DataLoader`类的信息,可以使用`print`语句打印相关信息: ```python print(dataloader) # 输出结果类似于: # <torch.utils.data.dataloader.DataLoader object at 0x7f8c9dd46c90> ``` 通过打印`dataloader`对象,我们可以看到其类别和内存地址等信息。 另外,还可以使用`for`循环迭代数据加载器,并打印每个批次的数据: ```python for data in dataloader: print(data) # 输出结果类似于: # tensor([[1, 2, 3, ...], [4, 5, 6, ...], ...]) # tensor([[7, 8, 9, ...], [10, 11, 12, ...], ...]) # ... ``` 以上代码会迭代输出每个批次的数据。每个批次都是一个`tensor`对象,其中包含了对应的数据。 通过使用以上方法,我们可以查看`DataLoader`类的信息,包括对象的类别和内存地址,以及每个批次的数据。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值