文章目录
一、涉及到的知识点
1.torch.utils.data
torch.utils.data 模块是 PyTorch 中用于处理数据加载和预处理的模块,提供了一些类和函数,帮助用户更方便地准备数据用于训练神经网络模型。
torch.utils.data 模块中一些常用的类和函数:
Dataset 类: Dataset 类是 torch.utils.data 模块中最重要的类之一,用户可以通过继承该类来自定义数据集。需要实现 __len__ 方法和 __getitem__ 方法,分别用于返回数据集的长度和获取单个样本。
DataLoader 类: DataLoader 类用于对 Dataset 对象进行封装,实现数据的批量加载和数据的打乱等功能。用户可以通过设置 batch_size、shuffle 等参数来自定义数据加载的方式。
Sampler 类: Sampler 类用于指定样本的采样策略,例如随机采样、顺序采样等。PyTorch 提供了一些内置的采样器,也可以自定义采样器来满足特定需求。
Transforms 类: transforms 模块提供了对数据进行预处理和转换的函数,如图片缩放、裁剪、正则化等。用户可以将这些函数组合成一个 Compose 对象,然后应用到数据集中。
IterableDataset 类: IterableDataset 类是另一种数据集类,用户可以通过继承该类来实现可迭代的数据集,适用于数据量较大无法一次性加载到内存的情况。
2.from PIL import Image
from PIL import Image 是 Python 中常用的导入 PIL(Python Imaging Library)模块中的 Image 类的语句。PIL 库提供了丰富的图像处理功能,包括打开、保存、调整大小、旋转、裁剪等图像操作。
一些函数:
open():打开一张图片文件,返回一个 Image 对象。
save():保存图像到指定路径。
show():显示图像。
resize():调整图像尺寸。
rotate():旋转图像。
crop():裁剪图像。
convert():转换图像的色彩模式。
thumbnail():生成缩略图。
paste():将一个图像粘贴到另一个图像上。
filter():应用滤镜效果。
transpose():翻转图像。
split():分离图像的通道。
3.os
os 模块是 Python 中用于与操作系统进行交互的标准库模块之一。它提供了许多函数用于创建、删除、移动文件和目录,以及与操作系统交互的其他功能。
一些常用的 os 模块函数:
os.getcwd():获取当前工作目录的路径。
os.chdir(path):改变当前工作目录到指定路径。
os.listdir(path):返回指定目录下的所有文件和目录的名称列表。
os.mkdir(path):创建一个新目录。
os.rmdir(path):删除指定目录。
os.remove(path):删除指定文件。
os.path.exists(path):检查指定路径(文件或目录)是否存在。
os.path.isfile(path):检查指定路径是否为文件。
os.path.isdir(path):检查指定路径是否为目录。
os.path.join(path1,path2,path3…):将两个或多个路径组合成一个完整的路径
4.with
在 Python 中,with 语句是用于简化管理资源的一种方式,特别是在处理文件 I/O 或者网络连接等需要手动关闭的资源时非常有用。with 语句可以帮助你确保在代码块执行完毕后自动关闭资源,而不需要显式地调用关闭方法。
with open('file.txt', 'r') as file:
data = file.read()
# 在这里执行文件操作,不需要显式关闭文件
# 在这里,文件已经被自动关闭
with open(‘file.txt’, ‘r’) as file: 创建了一个文件对象 file,并在 with 代码块中使用它。当代码块执行完毕时,无论是否发生异常,Python 都会自动调用文件对象的 close() 方法来关闭文件。这样可以确保资源被及时释放,避免资源泄漏。
二、代码
1.获取数据集的代码 read_data.py
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir=root_dir
self.label_dir=label_dir
self.path=os.path.join(self.root_dir,self.label_dir)
self.img_path=os.listdir(self.path)
def __getitem__(self,idx):
img_name=self.img_path[idx]
img_item_path=os.path.join(self.path,img_name)
img=Image.open(img_item_path)
label=self.label_dir
return img,label
def __len__(self):
return len(self.img_path)
root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
ants_dataset=MyData(root_dir,ants_label_dir)
bees_dataset=MyData(root_dir,bees_label_dir)
train_dataset=ants_dataset+bees_dataset
2.更改数据集存储方式的代码 rename_data.py
import os
root_dir = "dataset/train"
target_dir = "ants_img"
img_list = os.listdir(os.path.join(root_dir, target_dir))
label = target_dir.split('_')[0]
out_dir = "ants_label"
for img_name in img_list:
file_name = img_name.split('.jpg')[0]
with open(os.path.join(root_dir, out_dir, "{}.txt".format(file_name)), 'w') as f:
f.write(label)