PyTorch深度学习快速入门(b站小土堆)P3笔记
加载数据初认识
实战操作
from torch.utils.data import Dataset
from PIL import Image
//读取图片
import os
//创建MyData类,继承Dataset
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir = root_dir
//root_dir = "dataset/train"
self.label_dir = label_dir
//label_dir = "ants"
self.path = os.path.join(self.root_dir,self.label_dir)
//两个路径的拼接
self.imag_path = os.listdir(self.path)
//路径dataset/train/ants下的所有图片
def __getitem__(self,idx):
//idx是编号
img_name = self.imag_path[idx]
//如果idx&#