深度学习pytorch之一步导入自己的训练集

使用pytorch导入自己的数据有两种方法:

第一种:使用torchvision工具包中的datasets.ImageFolder(该方法较为简单)
第二种:使用torch.utils.data.Dataset,自定义导入数据的方式(需要根据不同情况编写代码)

第一种:torchvision.datasets.ImageFolder

要求:专门对于分类问题,将不同标签的图片分别放在不同的文件夹下,如图(将猫狗的图片分别放在两个不同的文件夹下),cat和dog文件夹放在data文件夹下。
在这里插入图片描述

dataset = torchvision.datasets.ImageFolder('path')  # path:data文件夹的路径
第二种:自定义读取方式

要求:没有要求,可以是分类问题,也可以是回归问题(例如输入和输出同为图片)

需要自定义一个Dataset

from PIL import Image
from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset):

    def __init__(self, data_dir, transform=None):
        self.imgs = self.get_imgs(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        img_path, label = self.imgs[index]
        img = Image.open(img_path)
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)
    
    def get_images(data_dir):
    	imgs = []
    	for root, dirs, _ in os.walk(data_dir):     # dirs 为各类名
    		for sub_dirs in dirs:
    			img_names = os.listdir(os.path.join(root, sub_dir))  # 图片路径
    			for i in range(len(img_names)):
    				img_name = img_names[i]    # 图片名
    				path_img = os.path.join(root, sub_dir, img_name)
    				imgs.append((path_img, int(dirs)))
trainset = MyDataset(train_dir,transforms)
trainloader = DataLoader(trainset, batch_size=1)

整个代码分三步:

  1. 需要自己先定义一个类,继承torch.utils.data.Dataset,并初始化参数:主要为设置图片的路径和预处理方法
class MyDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.imgs = self.get_imgs(data_dir)
        self.transform = transform

data_dir:图片保存的位置
transform:图像预处理方法(可以看博主的博客transforms了解)

  1. 自定义读取文件的路径
    创建一个空的list,将输入图片的路径和输出图片的路径以tuple的形式逐个存入。
    在本例中,图片以输入1,标签1,输入2,标签2,…的形式保存的。
    def get_images(data_dir):
    	imgs = []      # 创建一个空的list
    	for root, dirs, _ in os.walk(data_dir):     # 得到data_dir文件夹下所有的文件名(得到的dirs 为各类名)
    		for sub_dirs in dirs: 
    			img_names = os.listdir(os.path.join(root, sub_dir))  # 获得文件夹下所有图片路径
    			for i in range(len(img_names)//2):
    				img_input_name = img_names[i]    # 提取一个input图片名
    				img_label_name = img_name[i+1]  # 提取一个label图片名
    				path_img_1 = os.path.join(root, sub_dir, img_name) # 获得图片路径
    				path_img_2 = os.path.join(root, sub_dir, img_name) # 获得图片路径
    				imgs.append((path_img_1, path_img_2))

3.定义getitem,逐个读入图片
getitem为父类torch.utils.data.Dataset已经定义好的,它会逐个进行index=0,1,2,…。
只需要打开图片,进行图片预处理后,return即可。
定义len,返回样本数。

    def __getitem__(self, index):
        img_path, label = self.imgs[index]
        img = Image.open(img_path)    # 打开图片
        if self.transform is not None:
            img = self.transform(img)    # 图片预处理
        return img, label

    def __len__(self):
        return len(self.imgs)
补充知识点:
DataLoader

torch.utils.data.DataLoader:构建可迭代的数据装载器

DataLoader(dataset, batch_size=1, shuffle=False, num_works=0)

dataset:Dataset类,决定数据从哪儿读取及如何读取
batch_size:批大小
shuffle:每个epoch是否乱序
num_works:是否多进程读取数据

Dataset

torch.utils.data.Dataset:所有自定义的Dataset需要继承它,并且复写

class Dataset(object):
	def __init__(self):
		pass
	def __getitem__(self, index):
		pass
	def __len__(self, other):
		pass

len:返回数据集的大小
getitem:接受一个样本,返回一个索引

  • 11
    点赞
  • 63
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
PyTorch是一个开源的深度学习框架,它提供了丰富的工具和库,用于构建深度学习模型。在PyTorch中,我们可以通过构建神经网络层、优化器和损失函数来创建深度学习模型。下面就是一个PyTorch深度学习的demo。 首先,我们导入PyTorch库,然后定义一个简单的神经网络模型。比如说,我们可以创建一个包含两个全连接层的神经网络,其中第一个全连接层的输入维度是特征的维度,输出维度是隐藏层的大小,然后采用激活函数例如ReLU;第二个全连接层的输入维度是隐藏层的大小,输出维度是我们所需要的输出的大小。 接下来,我们定义损失函数和优化器。损失函数用于计算模型的预测输出和真实标签之间的差异,而优化器用于更新神经网络中的权重和偏置,以减小损失函数。在这个例子中,我们可以选择使用交叉熵损失函数和随机梯度下降优化器。 最后,我们加载数据集,将数据集分为训练集和测试集。然后我们可以迭代训练模型,在每一个epoch中,我们将训练数据输入到模型中,计算损失值,反向传播更新模型参数,然后在测试集上评估模型的性能。 通过这个简单的demo,我们可以了解如何使用PyTorch构建一个深度学习模型,并了解如何定义神经网络、损失函数和优化器,以及如何使用数据集训练和评估模型。希望这个demo可以帮助你入门PyTorch深度学习框架。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值