b站小土堆pytorch学习记录——P7 Dataset类代码实战

一、涉及到的知识点

1.torch.utils.data

torch.utils.data 模块是 PyTorch 中用于处理数据加载和预处理的模块,提供了一些类和函数,帮助用户更方便地准备数据用于训练神经网络模型。

torch.utils.data 模块中一些常用的类和函数:

Dataset 类: Dataset 类是 torch.utils.data 模块中最重要的类之一,用户可以通过继承该类来自定义数据集。需要实现 __len__ 方法和 __getitem__ 方法,分别用于返回数据集的长度和获取单个样本。

DataLoader 类: DataLoader 类用于对 Dataset 对象进行封装,实现数据的批量加载和数据的打乱等功能。用户可以通过设置 batch_size、shuffle 等参数来自定义数据加载的方式。

Sampler 类: Sampler 类用于指定样本的采样策略,例如随机采样、顺序采样等。PyTorch 提供了一些内置的采样器,也可以自定义采样器来满足特定需求。

Transforms 类: transforms 模块提供了对数据进行预处理和转换的函数,如图片缩放、裁剪、正则化等。用户可以将这些函数组合成一个 Compose 对象,然后应用到数据集中。

IterableDataset 类: IterableDataset 类是另一种数据集类,用户可以通过继承该类来实现可迭代的数据集,适用于数据量较大无法一次性加载到内存的情况。

2.from PIL import Image

from PIL import Image 是 Python 中常用的导入 PIL(Python Imaging Library)模块中的 Image 类的语句。PIL 库提供了丰富的图像处理功能,包括打开、保存、调整大小、旋转、裁剪等图像操作。

一些函数:
open():打开一张图片文件,返回一个 Image 对象。
save():保存图像到指定路径。
show():显示图像。
resize():调整图像尺寸。
rotate():旋转图像。
crop():裁剪图像。
convert():转换图像的色彩模式。
thumbnail():生成缩略图。
paste():将一个图像粘贴到另一个图像上。
filter():应用滤镜效果。
transpose():翻转图像。
split():分离图像的通道。

关于PIL的Image具体可参考这篇博客

3.os

os 模块是 Python 中用于与操作系统进行交互的标准库模块之一。它提供了许多函数用于创建、删除、移动文件和目录,以及与操作系统交互的其他功能。

一些常用的 os 模块函数:

os.getcwd():获取当前工作目录的路径。
os.chdir(path):改变当前工作目录到指定路径。
os.listdir(path):返回指定目录下的所有文件和目录的名称列表。
os.mkdir(path):创建一个新目录。
os.rmdir(path):删除指定目录。
os.remove(path):删除指定文件。
os.path.exists(path):检查指定路径(文件或目录)是否存在。
os.path.isfile(path):检查指定路径是否为文件。
os.path.isdir(path):检查指定路径是否为目录。
os.path.join(path1,path2,path3…):将两个或多个路径组合成一个完整的路径

4.with

在 Python 中,with 语句是用于简化管理资源的一种方式,特别是在处理文件 I/O 或者网络连接等需要手动关闭的资源时非常有用。with 语句可以帮助你确保在代码块执行完毕后自动关闭资源,而不需要显式地调用关闭方法。

with open('file.txt', 'r') as file:
    data = file.read()
    # 在这里执行文件操作,不需要显式关闭文件
# 在这里,文件已经被自动关闭

with open(‘file.txt’, ‘r’) as file: 创建了一个文件对象 file,并在 with 代码块中使用它。当代码块执行完毕时,无论是否发生异常,Python 都会自动调用文件对象的 close() 方法来关闭文件。这样可以确保资源被及时释放,避免资源泄漏。

二、代码

1.获取数据集的代码 read_data.py

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        self.path=os.path.join(self.root_dir,self.label_dir)
        self.img_path=os.listdir(self.path)


    def __getitem__(self,idx):
        img_name=self.img_path[idx]
        img_item_path=os.path.join(self.path,img_name)
        img=Image.open(img_item_path)
        label=self.label_dir
        return img,label

    def __len__(self):
        return len(self.img_path)


root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
ants_dataset=MyData(root_dir,ants_label_dir)
bees_dataset=MyData(root_dir,bees_label_dir)

train_dataset=ants_dataset+bees_dataset

2.更改数据集存储方式的代码 rename_data.py

import os

root_dir = "dataset/train"
target_dir = "ants_img"
img_list = os.listdir(os.path.join(root_dir, target_dir))
label = target_dir.split('_')[0]
out_dir = "ants_label"

for img_name in img_list:
    file_name = img_name.split('.jpg')[0]
    with open(os.path.join(root_dir, out_dir, "{}.txt".format(file_name)), 'w') as f:
        f.write(label)
  • 19
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

云霄星乖乖的果冻

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值