Pytorch数据封装进入网络前的几种方式
一、利用Pytorch的库方法
简单来说就是使用datasets.ImageFolder与torch.utils.data.DataLoader这两种方法。
数据存放格式需要如下
如下是代码示例
1、数据处理
train_transforms = transforms.Compose([
transforms.Resize((64, 64)),
transforms.RandomSizedCrop(48),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
2、数据读取与处理
train_datasets = datasets.ImageFolder(os.path.join(data_dir, 'train'), train_transforms)
# data_dir是总路径
3、数据封装
train_loaders = torch.utils.data.DataLoader(
dataset=train_datasets,
batch_size=batch_size,
shuffle=True,
num_workers=0
)
比较简单,接下来就循环读取并输入网络即可
二、自己构建数据封装
一般这种的比较常用,相比较调用库的方式稍显麻烦一点
1、图片标签读取并存入txt文件
数据存放如图所示,下面代码是生成txt文件,调用函数Data_division(‘数据总路径’,’保存训练txt文件路径‘,‘保存验证txt文件路径’)
CreateDataset.py
import os
import random
# 用于产生txt文件
def Data_division(data_path,save_txt_train_path,save_txt_eval_path):
# 将读取所有数据路径并存放在data_list当中
class_label = 0
data_list = []
list = os.listdir(data_path) # ['dataset', 'rename.py']
for i in range(0, len(list)):
path = os.path.join(data_path, list[i])
if os.path.isdir(path): # ./data\dataset 判断是否是文件夹
# 从这里开始
for j in os.listdir(path):