Pytorch框架使用 自建数据集
集合划分
首先将数据集放置如下:
├─class_1
│ data_1
│ ...
│ data_n
├─class_2
│ data_1
│ ...
│ data_n
├─...
│ data_1
│ ...
│ data_n
└─class_n
data_1
...
data_n
数据集的划分主要借助sklearn模块,若要分为train、val、test三个集合:
from sklearn.model_selection import train_test_split
def train_test_val_split(x, y, val_ratio=0.1, test_ratio=0.1, random_state=22):
    """Split (x, y) into train / val / test subsets.

    First carves off the combined val+test holdout (shuffled), then splits
    that holdout so the final fractions equal val_ratio and test_ratio.
    random_state makes both splits reproducible.
    """
    holdout = val_ratio + test_ratio
    x_train, x_rest, y_train, y_rest = train_test_split(
        x, y, test_size=holdout, random_state=random_state, shuffle=True)
    # Within the holdout, the test share is test_ratio / (val_ratio + test_ratio).
    x_val, x_test, y_val, y_test = train_test_split(
        x_rest, y_rest, test_size=test_ratio / holdout, random_state=random_state)
    return x_train, y_train, x_test, y_test, x_val, y_val
若要划分为train、val两个集合:
def train_val_split(x, y, val_ratio=0.1, random_state=22):
    """Shuffle (x, y) and hold out a val_ratio fraction as the validation set."""
    x_train, x_val, y_train, y_val = train_test_split(
        x, y, test_size=val_ratio, random_state=random_state, shuffle=True)
    return x_train, y_train, x_val, y_val
标签映射
数据集的标签一般是一个字符串,但是字符串标签无法直接用于训练,因此需要对其进行编码。
常用的sklearn库提供了多种标签编码方式,如OneHotEncoder
、OrdinalEncoder
和LabelEncoder(注意:BinaryEncoder并非sklearn自带,而是第三方category_encoders库提供的)
其中LabelEncoder是我比较常用的编码方式,简单来说就是把n个类别值编码为0~n-1之间的整数,建立起1-1映射
使用方法如下:
from sklearn.preprocessing import LabelEncoder
# Load the sample paths and their string labels.
data_path = './dataset'
x, y = load_my_dataset(data_path)  # x: list of data paths, y: list of labels
# Encode the string labels as integers 0..n-1 (a one-to-one mapping).
le = LabelEncoder()
y = le.fit_transform(y).astype(np.int64)
在编码后,为了测试、在线推理等阶段还原显示真实的标签,可以将映射表储存成字典,然后保存成txt文件
# Persist the index -> class-name mapping so predictions can be decoded later.
idx2class_path = os.path.join(data_path, 'idx2class.txt')
idx2class_dict = {}
for cl in le.classes_:
    idx2class_dict[le.transform([cl])[0]] = cl
dict2txt(idx2class_path, idx2class_dict)
# Save a dict to a txt file, one "key value" pair per line.
def dict2txt(text_path, data: dict):
    """Write *data* to *text_path*, one "key value" pair per line, sorted by key.

    A with-block guarantees the handle is closed even if a write raises
    (the original explicit close() leaked the handle on error).
    """
    with open(text_path, 'w') as file:
        # Sorting by key gives a stable, human-friendly file layout.
        for k, v in sorted(data.items()):
            file.write(str(k) + ' ' + str(v) + '\n')
在训练代码等项目中,使用idx2class.txt
将模型输出的标签转换为真实标签:
import os
from tools.dict_txt_converter import *
class Idx2class(object):
    """Maintains the index -> class-name mapping stored in idx2class.txt.

    If the mapping file is missing it is generated by scanning the class
    sub-directories of the data dir; otherwise it is loaded from the file.
    """

    def __init__(self, args):
        # args.dataset.data_dir: dataset root whose sub-directories are classes.
        self.data_dir = args.dataset.data_dir
        self.idx2class_dict = dict()
        self.class_name_ls = []
        self.gen_get_idx2cls_file()

    def gen_get_idx2cls_file(self):
        """Create idx2class.txt if it does not exist, otherwise load it."""
        idx2class_path = os.path.join(self.data_dir, 'idx2class.txt')
        if not os.path.exists(idx2class_path):
            # Sort directory names so the generated mapping is deterministic
            # and matches LabelEncoder's alphabetical ordering (os.scandir
            # order is filesystem-dependent).
            class_name = sorted(
                item.name for item in os.scandir(self.data_dir) if item.is_dir())
            self.idx2class_dict = dict(enumerate(class_name))
            dict2txt(idx2class_path, self.idx2class_dict)
        else:
            # Bugfix: txt2dict returns *string* keys, which broke integer
            # lookups in get_cls_name and sorted() order for >= 10 classes
            # ('10' < '2' as strings).  Normalize keys to int.
            self.idx2class_dict = {
                int(k): v for k, v in txt2dict(idx2class_path).items()}
        # Build the class-name list ordered by class index.
        for i in sorted(self.idx2class_dict):
            self.class_name_ls.append(self.idx2class_dict[i])

    def get_cls_name(self, cls_id):
        # cls_id: integer class index produced by the model.
        return self.idx2class_dict[cls_id]

    def get_cls_name_ls(self):
        return self.class_name_ls
def dict2txt(text_path, data: dict):
    """Write *data* to *text_path*, one "key value" pair per line, sorted by key.

    Uses a context manager so the file is closed even when a write raises;
    the original explicit close() was skipped on error, leaking the handle.
    """
    with open(text_path, 'w') as file:
        for k, v in sorted(data.items()):
            file.write(str(k) + ' ' + str(v) + '\n')
def txt2dict(text_path):
    """Read a "key value"-per-line txt file back into a dict.

    Returns a dict with *string* keys and values (no type conversion).
    Bugfixes vs the original:
      * split with maxsplit=1 so values containing spaces survive intact
        (split(' ')[1] silently truncated them);
      * blank/trailing-newline lines are skipped instead of raising
        IndexError;
      * a with-block closes the file even if parsing raises.
    """
    data = {}
    with open(text_path, 'r') as file:
        for line in file:
            line = line.strip()
            if not line:
                continue  # tolerate trailing blank lines
            k, v = line.split(' ', 1)
            data[k] = v
    return data
完整代码:
import os
from pathlib import Path
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
def dict2txt(text_path, data: dict):
    """Save *data* to *text_path* as "key value" lines sorted by key.

    The with-block guarantees the file is closed even when a write raises;
    the original explicit close() leaked the handle on error.
    """
    with open(text_path, 'w') as out:
        for key, value in sorted(data.items()):
            out.write(str(key) + ' ' + str(value) + '\n')
def save2txt(file, data, label):
    """Write parallel data/label lists to *file*, one "data<TAB>label" line each.

    open(..., 'w') creates the file if needed and truncates it otherwise.
    The per-line debug print in the original was removed: it dumped the
    entire dataset to stdout on every export.
    """
    lines = [f"{data[idx]}\t{label[idx]}\n" for idx in range(len(data))]
    with open(file, "w") as f:
        f.writelines(lines)
def train_test_val_split(x, y, val_ratio=0.1, test_ratio=0.1, random_state=22):
    """Two-stage split of (x, y) into train / val / test subsets.

    Stage 1 shuffles and holds out val_ratio + test_ratio of the data;
    stage 2 divides that holdout so each final fraction is as requested.
    random_state fixes both stages for reproducibility.
    """
    holdout_size = val_ratio + test_ratio
    x_train, x_hold, y_train, y_hold = train_test_split(
        x, y, test_size=holdout_size, random_state=random_state, shuffle=True)
    x_val, x_test, y_val, y_test = train_test_split(
        x_hold, y_hold,
        test_size=test_ratio / holdout_size,
        random_state=random_state)
    return x_train, y_train, x_test, y_test, x_val, y_val
def train_val_split(x, y, val_ratio=0.1, random_state=22):
    """Shuffle and split (x, y), holding out val_ratio for validation."""
    x_train, x_val, y_train, y_val = train_test_split(
        x, y,
        test_size=val_ratio,
        random_state=random_state,
        shuffle=True)
    return x_train, y_train, x_val, y_val
def load_my_dataset(data_path):
    """Collect sample paths and class labels from a class-per-directory tree.

    Returns (x, y): x holds each sample's path with the leading top-level
    path component stripped, y holds the class (directory) names.
    """
    data_path = Path(data_path)
    dirs = [e for e in data_path.iterdir() if e.is_dir()]
    x = []
    y = []
    for each_path in dirs:
        # Normalize separators so the splits below are platform-correct.
        each_path = os.path.normpath(each_path)
        cls = each_path.split(os.path.sep)[-1]
        for file in os.listdir(each_path):
            whole_path = os.path.join(each_path, file)
            # Bugfix: the original tested os.path.isdir(file) with the bare
            # filename, which is resolved against the CWD and almost never
            # matched -- nested directories were not actually skipped.
            if os.path.isdir(whole_path):
                continue
            # Bugfix: str.replace(prefix, '') removed *every* occurrence of
            # the first path component; slice off only the leading one.
            prefix = whole_path.split(os.path.sep)[0]
            x.append(whole_path[len(prefix):])
            y.append(cls)
    return x, y
if __name__ == '__main__':
    # Load sample paths and their string labels.
    data_path = './dataset'
    x, y = load_my_dataset(data_path)

    # Encode the string labels as integers 0..n-1 (one-to-one mapping).
    le = LabelEncoder()
    y = le.fit_transform(y).astype(np.int64)

    # Persist the index -> class-name mapping for decoding predictions later.
    idx2class_path = os.path.join(data_path, 'idx2class.txt')
    idx2class_dict = {}
    for cl in le.classes_:
        idx2class_dict.update({le.transform([cl])[0]: cl})
    dict2txt(idx2class_path, idx2class_dict)

    # Split the dataset and export each subset as a txt index file.
    train_save_path = os.path.join(data_path, 'train.txt')
    val_save_path = os.path.join(data_path, 'val.txt')
    test_save_path = os.path.join(data_path, 'test.txt')
    random_state = 2
    val_ratio = 0.2
    test_ratio = 0
    if test_ratio > 0:
        # Three-way split: also write the test subset.
        x_train, y_train, x_test, y_test, x_val, y_val = train_test_val_split(
            x, y, val_ratio, test_ratio, random_state)
        save2txt(test_save_path, x_test, y_test)
    else:
        # Two-way split only.
        x_train, y_train, x_val, y_val = train_val_split(
            x, y, val_ratio, random_state)
    save2txt(train_save_path, x_train, y_train)
    save2txt(val_save_path, x_val, y_val)
Transform类
Transform用于增强输入数据,常见的有归一化、随机裁剪、旋转、灰度化等
Dataset类
模板如下:
class MyDataset(torch.utils.data.Dataset):  # must inherit torch.utils.data.Dataset
    """Skeleton for a custom Dataset; fill in the TODOs for your data."""

    def __init__(self):
        # Initialize attributes inherited from the parent class.
        super(MyDataset, self).__init__()
        # TODO:
        # 1. Initialize parameters/helpers that __getitem__ will use.
        # 2. Build the lists of sample files and matching labels that
        #    __getitem__ will index into.
        # In short, this is where the basic state of the dataset is set up.
        pass

    def __getitem__(self, index):
        # TODO:
        # 1. Read ONE sample selected by *index* from the lists built in
        #    __init__ (e.g. numpy.fromfile, PIL.Image.open).
        # 2. Preprocess it (e.g. torchvision transforms).
        # 3. Return the data pair (e.g. image and label).
        # Note: this method handles the single sample at *index*.
        pass

    def __len__(self):
        # Return the dataset size.  Bugfix: the original template had
        # `return len()`, which is a TypeError; raise until implemented.
        raise NotImplementedError('return the number of samples here')
完整代码如下:
import os.path
from datasets.uac_data_aug import *
import torch
def load_data(data_path):
    """Read a tab-separated "path<TAB>label" index file into two parallel lists."""
    data, label = [], []
    with open(data_path, 'r') as f:
        for line in f.readlines():
            fields = line.split('\t')
            data.append(fields[0])
            # Drop the trailing newline from the label field.
            label.append(fields[1].replace('\n', ''))
    return data, label
class MyDataset(torch.utils.data.Dataset):  # must inherit torch.utils.data.Dataset
    """Dataset backed by a "path<TAB>label" txt index file.

    data_type selects train.txt / test.txt / val.txt under data_path.
    For 'test' and 'val', __getitem__ returns (sample, index); for 'train'
    it returns (sample, label).
    """

    def __init__(self, data_path, data_type, transform=None):
        # Map data_type onto the matching index file.
        if data_type == 'train':
            self.test = False
            self.data_path = os.path.join(data_path, 'train.txt')
        elif data_type == 'test':
            self.test = True
            self.data_path = os.path.join(data_path, 'test.txt')
        elif data_type == 'val':
            self.test = True
            self.data_path = os.path.join(data_path, 'val.txt')
        else:
            raise ValueError('Error Input Data Type!')
        # Load the parallel (path, label) lists from the index file.
        self.data, self.label = load_data(self.data_path)
        # Fall back to the default augmentation pipeline when none is given.
        if transform is None:
            self.transforms = Compose([
                Reshape()
            ])
        else:
            self.transforms = transform

    def __getitem__(self, index):
        if self.test:
            seq = self.data[index]
            seq = self.transforms(seq)
            return seq, index
        else:
            seq = self.data[index]
            # Bugfix: the original read self.labels, but __init__ stores the
            # list as self.label -- training lookups raised AttributeError.
            label = self.label[index]
            seq = self.transforms(seq)
            return seq, label

    def __len__(self):
        # Number of entries in the index file.
        return len(self.data)
if __name__ == '__main__':
    # Smoke test: build the validation split from the local data directory.
    data_path = r'.\data'
    data_type = 'val'
    mydataset = MyDataset(data_path, data_type)
Dataloader类
import torch.utils.data
# Build the training Dataset and wrap it in a shuffling DataLoader.
# NOTE(review): `args` is assumed to be defined by the surrounding training
# script and to provide args.train.batch_size -- confirm against the caller.
train_data_path = r'.\data'
train_dataset = MyDataset(train_data_path, 'train')
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.train.batch_size,
    shuffle=True,
)