使用pytorch导入自己的数据有两种方法:
第一种:使用torchvision工具包中的datasets.ImageFolder(该方法较为简单)
第二种:使用torch.utils.data.Dataset,自定义导入数据的方式(需要根据不同情况编写代码)
第一种:torchvision.datasets.ImageFolder
要求:专门对于分类问题,将不同标签的图片分别放在不同的文件夹下,如图(将猫狗的图片分别放在两个不同的文件夹下),cat和dog文件夹放在data文件夹下。
dataset = torchvision.datasets.ImageFolder('path') # path:data文件夹的路径
第二种:自定义读取方式
要求:没有要求,可以是分类问题,也可以是回归问题(例如输入和输出同为图片)
需要自定义一个Dataset
from PIL import Image
from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset):
def __init__(self, data_dir, transform=None):
self.imgs = self.get_imgs(data_dir)
self.transform = transform
def __getitem__(self, index):
img_path, label = self.imgs[index]
img = Image.open(img_path)
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)
def get_images(data_dir):
imgs = []
for root, dirs, _ in os.walk(data_dir): # dirs 为各类名
for sub_dirs in dirs:
img_names = os.listdir(os.path.join(root, sub_dir)) # 图片路径
for i in range(len(img_names)):
img_name = img_names[i] # 图片名
path_img = os.path.join(root, sub_dir, img_name)
imgs.append((path_img, int(dirs)))
trainset = MyDataset(train_dir,transforms)
trainloader = DataLoader(trainset, batch_size=1)
整个代码分三步:
- 需要自己先定义一个类,继承torch.utils.data.Dataset,并初始化参数:主要为设置图片的路径和预处理方法
class MyDataset(Dataset):
def __init__(self, data_dir, transform=None):
self.imgs = self.get_imgs(data_dir)
self.transform = transform
data_dir:图片保存的位置
transform:图像预处理方法(可以看博主的博客transforms了解)
- 自定义读取文件的路径
创建一个空的list,将输入图片的路径和输出图片的路径以tuple的形式逐个存入。
在本例中,图片以输入1,标签1,输入2,标签2,…的形式保存的。
def get_images(data_dir):
imgs = [] # 创建一个空的list
for root, dirs, _ in os.walk(data_dir): # 得到data_dir文件夹下所有的文件名(得到的dirs 为各类名)
for sub_dirs in dirs:
img_names = os.listdir(os.path.join(root, sub_dir)) # 获得文件夹下所有图片路径
for i in range(len(img_names)//2):
img_input_name = img_names[i] # 提取一个input图片名
img_label_name = img_name[i+1] # 提取一个label图片名
path_img_1 = os.path.join(root, sub_dir, img_name) # 获得图片路径
path_img_2 = os.path.join(root, sub_dir, img_name) # 获得图片路径
imgs.append((path_img_1, path_img_2))
3.定义getitem,逐个读入图片
getitem为父类torch.utils.data.Dataset已经定义好的,它会逐个进行index=0,1,2,…。
只需要打开图片,进行图片预处理后,return即可。
定义len,返回样本数。
def __getitem__(self, index):
img_path, label = self.imgs[index]
img = Image.open(img_path) # 打开图片
if self.transform is not None:
img = self.transform(img) # 图片预处理
return img, label
def __len__(self):
return len(self.imgs)
补充知识点:
DataLoader
torch.utils.data.DataLoader:构建可迭代的数据装载器
DataLoader(dataset, batch_size=1, shuffle=False, num_works=0)
dataset:Dataset类,决定数据从哪儿读取及如何读取
batch_size:批大小
shuffle:每个epoch是否乱序
num_works:是否多进程读取数据
Dataset
torch.utils.data.Dataset:所有自定义的Dataset需要继承它,并且复写
class Dataset(object):
def __init__(self):
pass
def __getitem__(self, index):
pass
def __len__(self, other):
pass
len:返回数据集的大小
getitem:接受一个样本,返回一个索引