首先打开pycharm,将环境设置成python3.7(pytorch)ps:我使用的是3.7版本,可以根据自己的选择去选择不同的版本。
在正式开始之前我们先了解一下pytorch中的数据集:
PyTorch 数据集(Dataset),数据读取和预处理是进行机器学习的首要操作,PyTorch提供了很多方法来完成数据的读取和预处理。有 Dataset
,TensorDataset
,DataLoader
,ImageFolder,
在本文我们将使用Dataset来进行数据的加载操作。
from torch.utils.data import Dataset
torch.utils.data 是代表这一数据的抽象类,你可以自己定义你的数据类,继承和重写这个抽象类,非常简单,只需要定义__init__、__len__
和__getitem__
这个三个函数(ps:都是python中的魔法函数):
class Mydata(Dataset):
#数据类初始化
def __init__(self, root_dir, label_dir):
#定义数据主文件夹地址
self.root_dir = root_dir
#定义标签地址
self.label_dir = label_dir
#地址的拼接
self.path = os.path.join(self.root_dir, self.label_dir)
#数据列表
self.img_path = os.listdir(self.path)
#得到单个数据
def __getitem__(self, idx):
#得到单一数据的图片名(这里是使用列表)
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
#使用PIL中的Image库打开文件
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
其中__init__():
__init__() 方法可以包含多个参数,但必须包含一个名为 self 的参数,且必须作为第一个参数。也就是说,类的构造方法最少也要有一个 self 参数,仅包含 self 参数的 __init__() 构造方法,又称为类的默认构造方法。在这里我们设置了两个参数:root_dir(为所有数据的根目录) 和 label_dir(标签目录)
__getitem__():
凡是在类中定义了这个__getitem__ 方法,那么它的实例对象(假定为p),可以像这样
p[key] 取值,当实例对象做p[key] 运算时,会调用类中的方法__getitem__。
一般如果想使用索引访问元素时,就可以在类中定义这个方法(__getitem__(self, key) )。
__len__():
__len__():的作用是返回容器中元素的个数。
在明白以上这些知识以后,我们进行数据的读取:
首先先实例化数据集:
#创建两个实例
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = Mydata(root_dir, ants_label_dir)
bees_dataset = Mydata(root_dir, bees_label_dir)
然后进行数据的读取:
ants
124