自定义Dataset,处理数据。

自定义Dataset

一、原因

在实际情况下,数据集都需要自己去处理,并不是把data和label已经整理好直接能放到Dataset,DataLoader就能训练的。举例一个简单的情况:在分类任务中label可能是由不同字符串代表不同的类别比如"A","B"代表不同的类别。很明显这样的label是不能参与训练的,所以我们需要把他们对应到相应的分类数字上。这样的处理我们放在Dataset中完成。

二、导包

import torch.utils.data.dataset as Dataset
import numpy as np

三、Dataset

一、准备数据

创建labels`, 1000长度的由'A', 'B', 'C', 'D'随即组成的列表

import random
characters = ['A', 'B', 'C', 'D']

labels = [random.choice(characters) for _ in range(1000)]
print(labels[:10])

查看前五个label
在这里插入图片描述

创建datas

datas = np.random.randn(1000, 12)
print(datas[0:5])

查看前五个data
在这里插入图片描述

二、自定义Dataset

自定义Dataset

class MyDataset(Dataset):
    def __init__(self, datas, label_list):
      print(datas[0:5])
      self.datas = datas
      self.labels = []

      self.label_list = label_list
      keys = list(set([y for y in self.label_list]))
      keys.sort()
      dictkeys = {key: ii for ii, key in enumerate(keys)}

      for i in range(len(self.label_list)):
        self.labels.append(dictkeys[label_list[i]])

    def __len__(self):
      return len(self.labels)

    def __getitem__(self, idx):
      return torch.FloatTensor(self.datas[idx]), self.labels[idx]

这段是将label转为数字分类的关键,将未处理的label一个个存到到列表中,用集合确保每个类别对应的key只有一个,然后sort()按字典排序,enumerate使得ii, key分别对应索引和key,然后组成一个字典使得每一个key对应一个索引,这索引也就是之后的数字类别。

      keys = list(set([y for y in self.label_list]))
      keys.sort()
      dictkeys = {key: ii for ii, key in enumerate(keys)}

然后一个个读取未处理的label放入字典中也就得到了对应的数字类别。
在这里插入图片描述

      for i in range(len(self.label_list)):
        self.labels.append(dictkeys[label_list[i]])

很容易理解这就是代表整个数据的大小。

def __len__(self):
      return len(self.labels)

这是返回data和label,idx表示是索引,所以返回这个索引对应的data和label

    def __getitem__(self, idx):
      return torch.FloatTensor(self.datas[idx]), self.labels[idx]

三、结果

创建Dataset和DataLoader对象,batch_size=5表示每一次迭代器取5对数据,shuffle=False表示不打乱数据,为了方便之后观察我们的处理情况,我后面也会出示True的情况。

mydataset = MyDataset(datas, labels)
dataloader = DataLoader(mydataset, batch_size=5, shuffle=False)

取一次迭代的数据

for idx, (data, label) in enumerate(dataloader):
  print(data)
  print(label)
  break

观察结果,实际情况与预期情况相同。
在这里插入图片描述
在这里插入图片描述

四、小结

大家肯定会好奇__getitem__的目的是什么,其实是我下一篇介绍自定义DataLoader涉及的,DataLoader中的batch_size其实就意味着getitem几次。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值