DataSet类(1)

DataSet类(1)

一、DataSet类

**原文链接:**https://zhuanlan.zhihu.com/p/500839903

DataSet类,可以帮助我们提取我们所需要的数据,用子类继承的DataSet类,给每一个数据进行编号(idx),在后面的神经网络中,初始化DataSet子类的实例后,就可以通过编号去实例对象中读取相应的数据,会自动调用__getitem__方法。同时子类对象也会获得相应真实的Label

**DataSet作用:**提供一种方式去获取数据及其对应真实的Label

DataSet类的子类,需要重写的函数:

获取每一个数据、以及其对应的Label
统计数据集中的数据数量    
1.DataSet类官方解读

在pycharm终端中输入:
在这里插入图片描述

得到以下输出:

Help on class Dataset in module torch.utils.data.dataset:

class Dataset(typing.Generic)
 |  An abstract class representing a :class:`Dataset`.
 |  
 |  All datasets that represent a map from keys to data samples should subclass
 |  it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
 |  data sample for a given key. Subclasses could also optionally overwrite
 |  :meth:`__len__`, which is expected to return the size of the dataset by many
 |  :class:`~torch.utils.data.Sampler` implementations and the default options
 |  of :class:`~torch.utils.data.DataLoader`.
 |  
 |  .. note::
 |    :class:`~torch.utils.data.DataLoader` by default constructs a index
 |    sampler that yields integral indices.  To make it work with a map-style
 |    dataset with non-integral indices/keys, a custom sampler must be provided.
 |  
 |  Method resolution order:
 |      Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __add__(self, other:'Dataset[T_co]') -> 'ConcatDataset[T_co]'
 |  
 |  __getattr__(self, attribute_name)
 |  
 |  __getitem__(self, index) -> +T_co
 |  
 |  ----------------------------------------------------------------------
 |  Class methods defined here:
 |  
 |  register_datapipe_as_function(function_name, cls_to_register, enable_df_api_tracing=False) from typing.GenericMeta
 |  
 |  register_function(function_name, function) from typing.GenericMeta
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |  
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |  
 |  __abstractmethods__ = frozenset()
 |  
 |  __annotations__ = {'functions': typing.Dict[str, typing.Callable]}
 |  
 |  __args__ = None
 |  
 |  __extra__ = None
 |  
 |  __next_in_mro__ = <class 'object'>
 |      The most base type
 |  
 |  __orig_bases__ = (typing.Generic[+T_co],)
 |  
 |  __origin__ = None
 |  
 |  __parameters__ = (+T_co,)
 |  
 |  __tree_hash__ = -9223371872509358054
 |  
 |  functions = {'concat': functools.partial(<function Dataset.register_da...
 |  
 |  ----------------------------------------------------------------------
 |  Static methods inherited from typing.Generic:
 |  
 |  __new__(cls, *args, **kwds)
 |      Create and return a new object.  See help(type) for accurate signature.

该类是一个抽象类,所有的数据集想要在数据与标签之间建立映射,都需要继承这个类,所有的子类都需要重写__getitem__方法,该方法根据索引值获取每一个数据并且获取其对应的Label,子类也可以重写__len__方法,返回数据集的size大小

二、蚂蚁蜜蜂分类数据集:

https://download.pytorch.org/tutorial/hymenoptera_data.zip

代码.py

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self,root_dir,label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        #  将路径进行合并
        self.path = os.path.join(self.root_dir,self.label_dir)
        # os.listdir(self.path) 将路径下的文件 装入列表进行返回
        self.img_path = os.listdir(self.path)


    def __getitem__(self, index):  # idx是index的简称,就是一个编号,以便以后数据集获取后,我们使用索引编号访问每个数据
        img_name = self.img_path[index]  # 按照索引取得列表下的文件名字
        # 路径合并,得到对应索引的文件路径
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
        # 获得路径下的图片
        img = Image.open(img_item_path)
        label = self.label_dir
        return img,label

    def __len__(self):
        # 得到数据集的长度
        return len(self.img_path)


if __name__ == '__main__':
    root_dir = r"/home/zxz/DEEPLEARNING/DEMO/dataset_1/hymenoptera_data/train"
    ants_label_dir = "ants"
    bees_label_dir = "bees"
    # 得到ants数据集
    ants_dataset = MyData(root_dir,ants_label_dir)
    bees_dataset = MyData(root_dir,bees_label_dir)

    # 获得蚂蚁数据集中的数据  同 ants_dataset.__getitem__(0)
    img,label = ants_dataset[0]
    # 显示图片
    img.show()
    print(label)

    img,label = bees_dataset[0]
    img.show()
    print(label)

    # 将两个数据集进行合并
    train_dataset  = ants_dataset + bees_dataset
    # 得到总数据集的数据量 同 train_dataset.__len__()
    print(len(train_dataset))

    img,label = train_dataset[0]
    img.show()

**将数据集进行拼接:**当开源的数据集存在不足时,我们可以自己制作相关数据集与已有的数据集进行拼接,再去训练

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值