import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
import pandas as pd
from skimage import io
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, csv_file, root_dir, transform=None):
# 读取标签文件
self.annotations = pd.read_csv(os.path.join(root_dir, csv_file))
# 定义文件目录
self.root_dir = root_dir
# 定义transform
self.transform = transform
# 返回数据集长度
def __len__(self):
return len(self.annotations)
# 获取数据的方法,会和Dataloader连用
def __getitem__(self, index):
# 获取图片路径,0表示csv文件的第一列
img_path = os.path.join(self.root_dir, self.annotations.iloc[index, 0])
# 读取图片
image = io.imread(img_path)
# 获取图片对应的标签,1表示csv文件的第二列
label = torch.tensor(int(self.annotations.iloc[index, 1]))
# 如果使用时附加了transform参数,则对图片应用转换
if self.
Python加载带有csv标签文件的图片数据集
于 2022-09-28 11:59:44 首次发布