pytorch导入数据集

1、概念:

Dataset:一种数据结构,存储数据及其标签

Dataloader:一种工具,可以将Dataset里的数据分批、打乱、批量加载并进行迭代等

(方便模型训练和验证)

Dataset就像一个大书架,存放着带有标签的数据书籍,并且这些书有编号(0,1,2...);

而Dataloader就像一个图书管理员,负责从书架上按需取出书籍并分批提供给读者。

2、Dataset的组织形式

train:训练集  val:验证集

一种方式是label作为数据文件夹的名字,

另一种方式是label和数据本身分开成两个文件夹(label文件夹里装的是和每个数据对应的.txt)

3、处理图像:PIL(Python Imaging Library)

pip install Pillow安装PIL
from PIL import Image

引入Image类(代表图像对象,

可以通过创建Image实例来操作图像)

img=Image.open('图像路径') 打开图像img.show() 显示图像
print(img.size) 输出(宽度,高度)

print(img.format)

输出图像格式(JPEG、PNG等)

resized_img=img.resize((宽度,高度)) 调整大小
resized_img=img.save('新路径') 保存为新文件

4、处理目录和文件:os

import os
cur_dir=os.getcwd()获取当前工作目录
files=os.listdir(cur_dir)列举当前目录下的所有子目录(文件和文件夹)
os.makedirs('new_folder')创建新文件夹(如果不存在)
os.remove('file.txt')删除文件(os.rmdir('empty_folder')删除空文件夹)
os.path.exists('some_path')检查路径是否存在
file_path=os.path.join('folder','file.txt')拼接路径
abs_path=os.path.abspath('file.txt)获取文件的绝对路径

5、代码

from torch.utils.data import Dataset #从torch的常用工具箱utils中拿data工具,然后引入Dataset类
from PIL import Image #处理图片要用到
import os #访问目录、获取图片的地址要用到

class MyData(Dataset): #让MyData类继承Dataset类
    def __init__(self,root_dir,label_dir): #数据集的初始化:要用到根目录和标签目录(这里把label作为数据文件夹的名字了)
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir) #根目录+标签目录=数据集的路径
        self.img_dir_list=os.listdir(self.path) #列举数据集目录下的每个数据(文件)

    def __getitem__(self,idx): #获取索引对应的数据
        img_dir=self.img_dir_list[idx] #得到索引对应的数据文件
        img_path=os.path.join(self.root_dir,self.label_dir,img_dir) #数据集路径+数据文件=数据文件路径
        img=Image.open(img_path)
        label=self.label_dir
        return img,label

    def __len__(self):
        return len(self.img_dir_list) #数据长度=数据集目录下的子文件数量

root_dir=r"dataset/hymenoptera_data/train"
ants_label_dir="ants"
ants_dataset=MyData(root_dir,ants_label_dir)
bees_label_dir="bees"
bees_dataset=MyData(root_dir,bees_label_dir)

train_dataset=ants_dataset+bees_dataset

### 回答1: 首先,你需要将本地数据集的图片等文件存放在一个文件夹中。然后,你可以使用 PyTorch 的 `torchvision.datasets.ImageFolder` 函数来读取数据。 例如: ```python import torch import torchvision # 定义数据集的路径 data_dir = '/path/to/your/data' # 定义数据加载器 data_loader = torch.utils.data.DataLoader( torchvision.datasets.ImageFolder(data_dir, transform=transforms), batch_size=4, shuffle=True, num_workers=4 ) # 使用 for 循环来遍历数据加载器并处理数据 for inputs, labels in data_loader: # 处理数据 ... ``` 在上面的代码中,`data_dir` 变量表示数据集所在的文件夹的路径,`transforms` 是一个可选的变换列表,用于对输入数据进行预处理。`batch_size` 参数表示每次返回的图像数量,`shuffle` 参数表示是否打乱数据,`num_workers` 参数表示使用多少个子进程来加载数据。 在处理数据时,每次迭代数据加载器都会返回一个批次的输入数据和对应的标签。你可以使用这些数据来训练你的模型。 希望这些信息对你有帮助。 ### 回答2: 导入本地数据集PyTorch可以通过以下步骤完成。 1. 首先,我们需要确保数据集已经下载并保存在本地计算机的某个文件夹下。比如,我们将数据集保存在名为"data"的文件夹中。 2. 接下来,我们需要导入PyTorch库,并使用`torchvision.datasets.ImageFolder`来创建一个数据集对象。ImageFolder是一个PyTorch提供的可以用于处理图像数据的数据集类。 ```python import torch from torchvision import datasets # 定义数据集路径 data_path = 'data/' # 创建一个数据集对象并指定数据集路径和数据变换(若有需要) dataset = datasets.ImageFolder(root=data_path, transform=None) ``` 3. 假设我们通过上述步骤已经成功加载了数据集。接下来,我们可以使用`torch.utils.data.DataLoader`来创建一个数据加载器,将数据集加载到模型中进行训练或测试。数据加载器可以帮助我们方便地处理数据批次、数据乱序和并行加载等问题。 ```python from torch.utils.data import DataLoader # 定义批处理大小和是否打乱数据 batch_size = 32 shuffle = True # 创建一个数据加载器并指定数据集和其他参数 data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle) ``` 通过以上步骤,我们就可以将本地数据集成功导入PyTorch中了。从而可以方便地使用PyTorch提供的功能进行数据预处理、模型构建和训练等操作。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值