通过重写Dataset类,对自己制作的数据集进行读取后传给DataLoader。主要用来完成从哪里读取数据和标签的功能。主要是__getitem__(返回数据集和标签)和__len__(返回数据的长度)这两个方法。
import numpy as np
import torch
import os
from PIL import Image
from torch.utils.data import Dataset
class MyDataset_1(Dataset):
"""
通过包含数据路径和标签的txt文件读取
txt_path:txt文本路径, 该文本包含了图像的路径信息, 以及标签信息
transform: 数据处理,对图像进行随机裁剪, 以及转换成tensor
"""
def __init__(self, txt_path, transform=None, target_transform=None):
super(MyDataset_1, self).__init__()
fh = open(txt_path)
imgs = []
# 一行一行读取txt文件
for line in fh:
line = line.rstrip() # 这一行就是图像的路径以及标签
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.imgs)
def __getitem__(self, index):
fn, label = self.imgs[index] # 通过index索引返回一个图像路径fn与标签label
img = Image.open(fn)
if self.transform:
img = self.transform(img)
return img, label
class MyDataset_2(Dataset):
"""
通过标签文件读取,csv文件前面是numpy数据,最后一列是label
"""
def __init__(self, csv_file):
super(MyDataset_2, self).__init__()
# xy是一个容器, 通过读取一个包含数据和标签信息的文件
xy = np.loadtxt(csv_file, delimiter=',', dtype=np.float32)
self.x_data = torch.from_numpy(xy[:, 0:-1])
self.y_data = torch.from_numpy(xy[:, -1])
self.len = len(xy) # 给后面的__len__()使用
def __len__(self):
return self.len
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
class MyDataset_3(Dataset):
"""
每个文件夹是一个类,每个文件夹中都是该类的图片(这种方法就等同于torchvision.datasets.ImageFolder)
"""
def __init__(self, dirname, transform=None):
super(MyDataset_3, self).__init__()
self.classes = os.listdir(dirname) # 有多少个目录就等于多少个类别,这边获得类别名
self.images = []
self.transform = transform
for i, classes in enumerate(self.classes):
classes_path = os.path.join(dirname, classes) # 类别目录
for image_name in os.listdir(classes_path): # 便利该类别中的图片
self.images.append((os.path.join(classes_path, image_name), i)) # 获得图片路径和类别名索引
def __len__(self):
return len(self.images)
def __getitem__(self, index):
image_name, classes = self.images[index]
image = Image.open(image_name)
if self.transform:
image = self.transform(image)
return image, classes