首先,我自己创建的彩色图像数据集是这样的:
标签是这样的:
#本文引荐了文章:http://t.csdn.cn/gkVNC;并作了注释与修改
#导入库
import os
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from torchvision.io import read_image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
#创建自定义数据集类
class Custom_Dataset(Dataset):
#函数,设置图像集路径索引、图像标签文件读取
def __init__(self, img_dir, img_label_dir, transform=None):
super().__init__()
self.img_dir = img_dir
self.img_labels = pd.read_csv(img_label_dir)
self.transform