数据载入由什么组成?
数据载入由dataset和dataloader组成。
dataset:提供一种方式去获取数据及其label
dataloader: 为后面网络提供不同的数据形式
1. Dataset的功能
dataset主要为了实现两个功能
1.如何获取每个数据及其label
2.告诉我们总共有多少的数据
2. Dataset代码:
1) 查看官方文档解释
首先,在anaconda prompt中输入如下代码,打开jupyter环境
conda activate <pytorch环境名称> #激活pytorch环境
jupyter notebook #打开jupyter
然后,在jupyter中创建新的文件,并输入以下指令既可以看到关于dataset官方文档解释。
from torch.utils.data import Dataset
help(Dataset)
2) pycharm中进行程序编写
(1)下载数据集
蚁蜜蜂分类数据集
https://download.pytorch.org/tutorial/hymenoptera_data.zip
(2)
建立dataset文件,并将数据集放入程序下
(3)编写数据集载入程序,实现dataset两个功能,第一,如何获取每个数据及其label
第二,告诉我们总共有多少的数据。
from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision import transforms
class mydata(Dataset):
#设置全局参数
def __init__(self,root_dir,label_dir):
self.root_dir=root_dir
self.label_dir=label_dir
# 获得图片的路径地址
self.path=os.path.join(self.root_dir,self.label_dir) #os.path.join()函数用于路径拼接文件路径,可以传入多个路径。如果不存在以’/’开始的参数,则函数会自动加上
# 获得图片的所有列表
self.img_path=os.listdir(self.path) #os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。
#获取每一个图片
def __getitem__(self, item):
#获取单张图片名称
img_name=self.img_path[item]
#获取单张图片相对路径
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
#读取单张图片
img=Image.open(img_item_path)
img=img.resize((256,256),Image.ANTIALIAS) #统一图片尺寸
trans = transforms.ToTensor() #转换为tensor类型
img_tensor = trans(img)#转换为tensor类型
#获取lable
label=self.label_dir
return img_tensor, label
#列表有多长
def __len__(self):
return len(self.img_path)
3. Dataloader的功能
dataloader的功能是为了实现从dataset中取数据,例如,每次取多少数据?,数据集是否打乱?,加载过程是单进程还是多进程?,如果最后剩余数据不足一次需要获取数据,剩余数据是否舍弃。
4. Dataloader代码
1) 查看官方文档解释
在jupyter中创建新的文件,并输入以下指令既可以看到关于dataset官方文档解释。
from torch.utils.data import DataLoader
help(DataLoader)
2) pycharm中进行程序编写
新建.py文件,写入以下内容
from dataset import mydata
from torch.utils.data import DataLoader
import torch
#准备测试数据集
root_dir= "dataset/val"
bees_label_dir="bees"
test_dataset=mydata(root_dir,bees_label_dir)#输入数据集路径
#数据载入
test_loader=DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False)
for data1 in test_loader:
imgs,labs=data1
print(imgs.shape)
print(labs)
输出如下内容则表示数据载入成功
感谢: PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
视频网址:https://www.bilibili.com/video/BV1hE411t7RN?p=15&vd_source=5b6e0605c1ed0f1db9c92503dd5994e0