参考视频:
P5. PyTorch加载数据初认识_哔哩哔哩_bilibili
P6. Dataset类代码实战_哔哩哔哩_bilibili
1. dataset
1.1. 作用:
- 获取数据&其label
- 告知共有多少数据
1.2. 实战
1.2.1. 数据集准备
下载数据集:https://download.pytorch.org/tutorial/hymenoptera_data.zip
该数据集是蚂蚁和蜜蜂的二分类数据集,label就是文件名字,数据集结构如下
hymenoptera_data
|__train
| |__ants
| |__bees
|
|__val
|__ants
|__bees
常见的数据集组织形式:
- 文件名是label
- label写在.txt里
- 图片名就是label
1.2.2. python文件
from torch.utils.data import Dataset # 导入包
from PIL import Image # 用来读取图片
import os # 用来获取所有数据的地址
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
#root_dir是train或val文件夹的地址
#label_dir是ants或bees
self.root_dir = root_dir
self.label_dir = label_dir
# 将root_dir和label_dir拼接在一起得到图片的地址eg:train/ants
self.path = os.path.join(self.root_dir,self.label_dir)
# 得到图片地址下所有图片的名字列表eg:['1.jpg','2.jpg','3.jpg']
self.img_path = os.listdir(self.path)
def __getitem__(self,idx):
#某张图片的名字eg:'1.jpg'
img_name = self.img_path[idx]
# 某张图片的地址eg:train/ants/1.jpg
img_item_path = os.path.join(self.root_dir,self_label_dir,img_name)
# label是文件名eg:ants
label = self.label_dir
return img_item_path,label # 返回第idx张图片的地址和label
def __len__(self):
return len(self.img_path) #返回总共有多少张图片
root_dir = 'hymenoptera_data/train' # 输入自己的数据地址
ants_label_dir = 'ants'
bees_label_dir = 'bees'
ants_dataset = MyData(root_dir, ants_label_dir) #创建一个实例
bees_dataset = MyData(root_dir, bees_label_dir) #创建一个实例
train_dataset = ants_dataset + bees_dataset # 拼接两个数据集=总的训练集
2. dataloader
2.1. 作用:
为后面的网络提供不同的数据形式(让网络能够一个batch一个batch的处理数据)
问题:
- 如何使用os、PIL
- 不同的数据集组织形式如何加载