The Dataset Class (1)
I. The Dataset Class
**Original link:** https://zhuanlan.zhihu.com/p/500839903
The Dataset class helps us pull out the data we need. We write a subclass of Dataset in which every sample gets an index (idx). Later, in the neural-network code, once an instance of the subclass has been created we can read a sample through its index: indexing the instance automatically calls `__getitem__`, and the sample comes back together with its ground-truth label.
**What Dataset provides:** a way to fetch each data sample along with its corresponding ground-truth label.
A subclass of Dataset needs to override two functions:
- fetching a single sample and its corresponding label (`__getitem__`)
- counting the number of samples in the dataset (`__len__`)
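In code, the contract is just those two methods. Below is a minimal sketch using a hypothetical `PairDataset` that holds in-memory `(sample, label)` pairs; in real code the class would inherit from `torch.utils.data.Dataset`, which additionally provides things like `__add__` for concatenation:

```python
class PairDataset:
    """Minimal sketch of the map-style dataset protocol:
    just __getitem__ + __len__.
    (A real subclass would inherit from torch.utils.data.Dataset.)"""

    def __init__(self, pairs):
        # pairs: a list of (sample, label) tuples, e.g. (image, "ants")
        self.pairs = pairs

    def __getitem__(self, index):
        # Return the sample and its ground-truth label at this index.
        return self.pairs[index]

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.pairs)

ds = PairDataset([("img0", "ants"), ("img1", "bees")])
print(len(ds))   # 2
print(ds[1])     # ('img1', 'bees')
```

Indexing `ds[1]` goes through `__getitem__` automatically, which is exactly the mechanism described above.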
1. Reading the official Dataset documentation
Importing `Dataset` from `torch.utils.data` and calling `help(Dataset)` in the PyCharm Python console produces the following output:
```
Help on class Dataset in module torch.utils.data.dataset:

class Dataset(typing.Generic)
 |  An abstract class representing a :class:`Dataset`.
 |
 |  All datasets that represent a map from keys to data samples should subclass
 |  it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
 |  data sample for a given key. Subclasses could also optionally overwrite
 |  :meth:`__len__`, which is expected to return the size of the dataset by many
 |  :class:`~torch.utils.data.Sampler` implementations and the default options
 |  of :class:`~torch.utils.data.DataLoader`.
 |
 |  .. note::
 |    :class:`~torch.utils.data.DataLoader` by default constructs a index
 |    sampler that yields integral indices. To make it work with a map-style
 |    dataset with non-integral indices/keys, a custom sampler must be provided.
 |
 |  Method resolution order:
 |      Dataset
 |      typing.Generic
 |      builtins.object
 |
 |  Methods defined here:
 |
 |  __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]'
 |
 |  __getattr__(self, attribute_name)
 |
 |  __getitem__(self, index) -> +T_co
 |
 |  ----------------------------------------------------------------------
 |  Class methods defined here:
 |
 |  register_datapipe_as_function(function_name, cls_to_register, enable_df_api_tracing=False) from typing.GenericMeta
 |
 |  register_function(function_name, function) from typing.GenericMeta
 |
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |
 |  __dict__
 |      dictionary for instance variables (if defined)
 |
 |  __weakref__
 |      list of weak references to the object (if defined)
 |
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |
 |  __abstractmethods__ = frozenset()
 |
 |  __annotations__ = {'functions': typing.Dict[str, typing.Callable]}
 |
 |  __args__ = None
 |
 |  __extra__ = None
 |
 |  __next_in_mro__ = <class 'object'>
 |      The most base type
 |
 |  __orig_bases__ = (typing.Generic[+T_co],)
 |
 |  __origin__ = None
 |
 |  __parameters__ = (+T_co,)
 |
 |  __tree_hash__ = -9223371872509358054
 |
 |  functions = {'concat': functools.partial(<function Dataset.register_da...
 |
 |  ----------------------------------------------------------------------
 |  Static methods inherited from typing.Generic:
 |
 |  __new__(cls, *args, **kwds)
 |      Create and return a new object. See help(type) for accurate signature.
```
This class is abstract: any dataset that wants to establish a mapping from data to labels should subclass it. Every subclass must override `__getitem__`, which, given an index, fetches a sample together with its corresponding label; a subclass may also override `__len__`, which returns the size of the dataset.
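As the note in the help text says, `DataLoader` by default constructs an index sampler that yields integer indices. Conceptually, driving a map-style dataset sequentially amounts to the following plain-Python sketch (this is an illustration, not PyTorch's actual `DataLoader` implementation):

```python
def iterate_map_style(dataset):
    """Sketch of a sequential index sampler driving a map-style dataset:
    yield dataset[idx] for idx = 0 .. len(dataset) - 1."""
    for idx in range(len(dataset)):
        yield dataset[idx]

# Any object providing __getitem__ and __len__ works as a map-style dataset;
# a plain list is the simplest example:
samples = ["ant_0.jpg", "ant_1.jpg", "bee_0.jpg"]
print(list(iterate_map_style(samples)))
```

This is why `__len__` matters to samplers: without it, the default sampler cannot know which integer indices are valid.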
II. The ants/bees classification dataset:
https://download.pytorch.org/tutorial/hymenoptera_data.zip
code.py
```python
from torch.utils.data import Dataset
from PIL import Image
import os


class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # Join the two path components
        self.path = os.path.join(self.root_dir, self.label_dir)
        # os.listdir(self.path) returns the file names under the path as a list
        self.img_path = os.listdir(self.path)

    def __getitem__(self, index):  # "idx" is short for index: the number we later use to access each sample
        img_name = self.img_path[index]  # file name at this index in the list
        # Join the components to get the file path for this index
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        # Open the image at that path
        img = Image.open(img_item_path)
        label = self.label_dir  # the folder name doubles as the label
        return img, label

    def __len__(self):
        # Length of the dataset
        return len(self.img_path)


if __name__ == '__main__':
    root_dir = r"/home/zxz/DEEPLEARNING/DEMO/dataset_1/hymenoptera_data/train"
    ants_label_dir = "ants"
    bees_label_dir = "bees"
    # Build the ants and bees datasets
    ants_dataset = MyData(root_dir, ants_label_dir)
    bees_dataset = MyData(root_dir, bees_label_dir)
    # Fetch a sample from the ants dataset; same as ants_dataset.__getitem__(0)
    img, label = ants_dataset[0]
    # Show the image
    img.show()
    print(label)
    img, label = bees_dataset[0]
    img.show()
    print(label)
    # Concatenate the two datasets
    train_dataset = ants_dataset + bees_dataset
    # Total number of samples; same as train_dataset.__len__()
    print(len(train_dataset))
    img, label = train_dataset[0]
    img.show()
```
**Concatenating datasets:** when an open dataset is too small, we can build our own data, concatenate it with the existing dataset, and then train on the combined set.
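`ants_dataset + bees_dataset` works because `Dataset.__add__` (visible in the help output above) returns a `ConcatDataset`, which maps a global index to the right sub-dataset using cumulative sizes. The following is a simplified pure-Python sketch of that behavior, not PyTorch's actual implementation:

```python
import bisect


class ConcatSketch:
    """Simplified sketch of torch.utils.data.ConcatDataset:
    presents several datasets as one by translating a global index
    into (which dataset, local index) via cumulative lengths."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        self.cumulative = []          # running totals of the lengths
        total = 0
        for d in self.datasets:
            total += len(d)
            self.cumulative.append(total)

    def __len__(self):
        return self.cumulative[-1]

    def __getitem__(self, index):
        # Find the first cumulative total that exceeds `index` ...
        ds_idx = bisect.bisect_right(self.cumulative, index)
        # ... then subtract everything before that dataset to get the local index.
        local = index if ds_idx == 0 else index - self.cumulative[ds_idx - 1]
        return self.datasets[ds_idx][local]


# Plain lists stand in for the two MyData datasets:
ants = [("ant_img", "ants")] * 3
bees = [("bee_img", "bees")] * 2
both = ConcatSketch([ants, bees])
print(len(both))   # 5
print(both[3])     # ('bee_img', 'bees') -- first sample of the second dataset
```

Index 3 falls past the three ant samples, so the concatenated view transparently serves the first bee sample; this is exactly how `train_dataset` above behaves.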