视频教程：P5. PyTorch 加载数据初认识（哔哩哔哩 bilibili）
数据集：ants / bees（蚂蚁 / 蜜蜂）二分类数据集
数据集下载地址：https://pan.baidu.com/s/1jZoTmoFzaTLWh4lKBHVbEA （提取码：5suq）
代码如下（建议先跟着 UP 主边听边做，最后再自己独立敲一遍）：
from torch.utils.data import Dataset
from PIL import Image
import os
class MyDataSet(Dataset):
    """Dataset of the files stored in one class folder, ``<dir_root>/<label>/``.

    Every sample is paired with the folder name itself as its label, so two
    instances (e.g. "ants" and "bees") can be combined with ``+`` to form a
    two-class dataset.

    Args:
        dir_root: Root directory containing one sub-folder per class
            (for example ``"data_set/train"``).
        label: Name of the class sub-folder to read; it is also returned as
            the label of every sample.
    """

    def __init__(self, dir_root, label):
        self.dir_root = dir_root
        self.label = label
        # Folder holding this class's images.
        path = os.path.join(self.dir_root, self.label)
        # os.listdir returns entries in arbitrary, platform-dependent order;
        # sort so a given index always maps to the same file, and skip
        # sub-directories, which Image.open cannot read.
        self.image_name_list = sorted(
            name
            for name in os.listdir(path)
            if os.path.isfile(os.path.join(path, name))
        )

    def __getitem__(self, index):
        """Return ``(image, label)`` for the ``index``-th file in sorted order.

        ``image`` is the PIL image with its pixel data already loaded;
        ``label`` is the class folder name given to the constructor.
        """
        image_name = self.image_name_list[index]
        image_path = os.path.join(self.dir_root, self.label, image_name)
        image = Image.open(image_path)
        # Image.open is lazy and keeps the file open until pixels are read;
        # load() reads them now so Pillow can close the file (for
        # single-frame images) instead of holding one handle per sample.
        image.load()
        return image, self.label

    def __len__(self):
        """Return the number of files found in the class folder."""
        return len(self.image_name_list)
# One dataset per class folder of the training split.
data_ants = MyDataSet(dir_root="data_set/train", label="ants")
# Example of fetching a single sample:  image, label = data_ants[0]
data_bees = MyDataSet(dir_root="data_set/train", label="bees")
# Merge the two class datasets; torch's Dataset defines "+" to build a
# ConcatDataset that yields every ants sample first, then every bees sample.
data = data_ants + data_bees