Pytorch框架使用 自建数据集

集合划分

首先将数据集放置如下:

├─class_1
│      data_1
│      ...
│      data_n
├─class_2
│      data_1
│      ...
│      data_n
├─...
│      data_1
│      ...
│      data_n
└─class_n
        data_1
        ...
        data_n

数据集的划分主要借助sklearn模块,若要分为train、val、test三个集合:

from sklearn.model_selection import train_test_split

def train_test_val_split(x, y, val_ratio=0.1, test_ratio=0.1, random_state=22):
    # random_state for reproduction
    # shuffle must be 'True'
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=val_ratio + test_ratio,
                                                          random_state=random_state, shuffle=True)
    x_val, x_test, y_val, y_test = train_test_split(x_test, y_test, test_size=test_ratio / (test_ratio + val_ratio),
                                                        random_state=random_state)
    return x_train, y_train, x_test, y_test, x_val, y_val

若要划分为train、val两个集合:

def train_val_split(x, y, val_ratio=0.1, random_state=22):
    x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=val_ratio, random_state=random_state, shuffle=True)
    return x_train, y_train, x_val, y_val

标签映射

数据集的标签一般是一个字符,但是字符标签无法直接用于训练,因此需要对其进行编码。
常用的sklearn库提供了多种标签编码方式,如OneHotEncoder, BinaryEncoderLabelEncoder
其中LabelEncoder是我比较常用的编码方式,简单来说就是把n个类别值编码为0~n-1之间的整数,建立起1-1映射
使用方法如下:

from sklearn.preprocessing import LabelEncoder

# 加载数据路径与标签
data_path = './dataset'
x, y = load_my_dataset(data_path)	# x为数据list, y为标签list

# 对标签进行编码
le = LabelEncoder()    # 把n个类别值编码为0~n-1之间的整数,建立起1-1映射
y = le.fit_transform(y).astype(np.int64)

在编码后,为了测试、在线推理等阶段还原显示真实的标签,可以将映射表储存成字典,然后保存成txt文件

# 保存编码映射表
idx2class_path = os.path.join(data_path, 'idx2class.txt')
idx2class_dict = {}
for cl in le.classes_:
    idx2class_dict.update({le.transform([cl])[0]: cl})
dict2txt(idx2class_path, idx2class_dict)
# 将字典保存成txt文件
def dict2txt(text_path, data: dict):
    # 先创建并打开一个文本文件
    file = open(text_path, 'w')

    # 遍历字典的元素,将每项元素的key和value分拆组成字符串,注意添加分隔符和换行符
    # 字典输出的项是无序的,如果想按照字典的key排序输出的话,可以按照下面的方式实现
    for k, v in sorted(data.items()):
        file.write(str(k) + ' ' + str(v) + '\n')

    # 注意关闭文件
    file.close()

在训练代码等项目中,使用idx2class.txt将模型输出的标签转换为真实标签:

import os

from tools.dict_txt_converter import *


class Idx2class(object):
    def __init__(self, args):
        self.data_dir = args.dataset.data_dir
        self.idx2class_dict = dict()
        self.class_name_ls = []
        self.gen_get_idx2cls_file()

    def gen_get_idx2cls_file(self):
        # 判断是否存在idx和类别名的映射文件: idx2class.txt
        idx2class_path = os.path.join(self.data_dir, 'idx2class.txt')
        if not os.path.exists(idx2class_path):
            class_name = []
            for item in os.scandir(self.data_dir):
                if item.is_dir():
                    class_name.append(item.name)
            for key in range(len(class_name)):
                self.idx2class_dict[key] = class_name[key]
            dict2txt(idx2class_path, self.idx2class_dict)
        else:
            self.idx2class_dict = txt2dict(idx2class_path)

        # 转为class_name_list
        for i in sorted(self.idx2class_dict):
            self.class_name_ls.append(self.idx2class_dict[i])

    def get_cls_name(self, cls_id):
        return self.idx2class_dict[cls_id]

    def get_cls_name_ls(self):
        return self.class_name_ls

def dict2txt(text_path, data: dict):
    # 先创建并打开一个文本文件
    file = open(text_path, 'w')

    # 遍历字典的元素,将每项元素的key和value分拆组成字符串,注意添加分隔符和换行符
    # 字典输出的项是无序的,如果想按照字典的key排序输出的话,可以按照下面的方式实现
    for k, v in sorted(data.items()):
        file.write(str(k) + ' ' + str(v) + '\n')

    # 注意关闭文件
    file.close()


def txt2dict(text_path):
    # 声明一个空字典,来保存文本文件数据
    data = {}

    # 打开文本文件
    file = open(text_path, 'r')

    # 遍历文本文件的每一行,strip可以移除字符串头尾指定的字符(默认为空格或换行符)或字符序列
    for line in file.readlines():
        line = line.strip()
        k = line.split(' ')[0]
        v = line.split(' ')[1]
        data[k] = v

    # 依旧是关闭文件
    file.close()

    return data

完整代码:

import os
from pathlib import Path

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder


def dict2txt(text_path, data: dict):
    # 先创建并打开一个文本文件
    file = open(text_path, 'w')

    # 遍历字典的元素,将每项元素的key和value分拆组成字符串,注意添加分隔符和换行符
    # 字典输出的项是无序的,如果想按照字典的key排序输出的话,可以按照下面的方式实现
    for k, v in sorted(data.items()):
        file.write(str(k) + ' ' + str(v) + '\n')

    # 注意关闭文件
    file.close()


def save2txt(file, data, label):
    # 判断文件是否存在,不存在则创建
    data_num = len(data)
    with open(file, "w") as f:
        for idx in range(data_num):
            temp = f"{data[idx]}\t{label[idx]}\n"
            print(temp)
            f.writelines(temp)


def train_test_val_split(x, y, val_ratio=0.1, test_ratio=0.1, random_state=22):
    # random_state for reproduction
    # shuffle must be 'True'
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=val_ratio + test_ratio,
                                                        random_state=random_state, shuffle=True)
    x_val, x_test, y_val, y_test = train_test_split(x_test, y_test, test_size=test_ratio / (test_ratio + val_ratio),
                                                    random_state=random_state)

    return x_train, y_train, x_test, y_test, x_val, y_val


def train_val_split(x, y, val_ratio=0.1, random_state=22):
    x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=val_ratio, random_state=random_state,
                                                      shuffle=True)
    return x_train, y_train, x_val, y_val


def load_my_dataset(data_path):
    data_path = Path(data_path)
    dirs = [e for e in data_path.iterdir() if e.is_dir()]
    x = []
    y = []
    for each_path in dirs:
        # 分隔符规范化
        each_path = os.path.normpath(each_path)
        cls = each_path.split(os.path.sep)[-1]

        for file in os.listdir(each_path):
            if not os.path.isdir(file):
                whole_path = os.path.join(each_path, file)
                prefix = whole_path.split(os.path.sep)[0]

                x.append(whole_path.replace(prefix, ''))
                y.append(cls)

    return x, y


if __name__ == '__main__':
    # 加载数据路径与标签
    data_path = './dataset'
    x, y = load_my_dataset(data_path)

    # 对标签进行编码
    le = LabelEncoder()  # 把n个类别值编码为0~n-1之间的整数,建立起1-1映射
    y = le.fit_transform(y).astype(np.int64)

    # 保存编码映射表
    idx2class_path = os.path.join(data_path, 'idx2class.txt')
    idx2class_dict = {}
    for cl in le.classes_:
        idx2class_dict.update({le.transform([cl])[0]: cl})
    dict2txt(idx2class_path, idx2class_dict)

    # 划分
    train_save_path = os.path.join(data_path, 'train.txt')
    val_save_path = os.path.join(data_path, 'val.txt')
    test_save_path = os.path.join(data_path, 'test.txt')

    random_state = 2
    val_ratio = 0.2
    test_ratio = 0

    if test_ratio > 0:
        x_train, y_train, x_test, y_test, x_val, y_val = train_test_val_split(x, y, val_ratio, test_ratio, random_state)
        save2txt(test_save_path, x_test, y_test)
    else:
        x_train, y_train, x_val, y_val = train_val_split(x, y, val_ratio, random_state)

    save2txt(train_save_path, x_train, y_train)
    save2txt(val_save_path, x_val, y_val)



Transform类

Transform用于增强输入数据,常见的有归一化、随机裁剪、旋转、灰度化等

Dataset类

​模板如下:

class MyDataset(torch.utils.data.Dataset):#需要继承torch.utils.data.Dataset
    def __init__(self):
        #对继承自父类的属性进行初始化
        super(MyDataset,self).__init__()
        # TODO
        #1、初始化一些参数和函数,方便在__getitem__函数中调用。
        #2、制作__getitem__函数所要用到的图片和对应标签的list。
        #也就是在这个模块里,我们所做的工作就是初始化该类的一些基本参数。
        pass
    def __getitem__(self, index):
        # TODO
        #1、根据list从文件中读取一个数据(例如,使用numpy.fromfile,PIL.Image.open)。
        #2、预处理数据(例如torchvision.Transform)。
        #3、返回数据对(例如图像和标签)。
        #这里需要注意的是,这步所处理的是index所对应的一个样本。
        pass
    def __len__(self):
        #返回数据集大小
        return len()

完整代码如下:

import os.path
from datasets.uac_data_aug import *

import torch


def load_data(data_path):
    data = []
    label = []
    with open(data_path, 'r') as f:
        lines = f.readlines()
        for line in lines:
            line = line.split('\t')
            data.append(line[0])
            label.append(line[1].replace('\n', ''))
    return data, label


class MyDataset(torch.utils.data.Dataset):  # 需要继承torch.utils.data.Dataset
    def __init__(self, data_path, data_type, transform=None):
        # 根据data_type确定data_path
        if data_type == 'train':
            self.test = False
            self.data_path = os.path.join(data_path, 'train.txt')
        elif data_type == 'test':
            self.test = True
            self.data_path = os.path.join(data_path, 'test.txt')
        elif data_type == 'val':
            self.test = True
            self.data_path = os.path.join(data_path, 'val.txt')
        else:
            raise ValueError('Error Input Data Type!')

        # 加载数据
        self.data, self.label = load_data(self.data_path)

        # 根据transform对数据进行处理
        if transform is None:
            self.transforms = Compose([
                Reshape()
            ])
        else:
            self.transforms = transform

    def __getitem__(self, index):
        if self.test:
            seq = self.data[index]
            seq = self.transforms(seq)
            return seq, index
        else:
            seq = self.data[index]
            label = self.labels[index]
            seq = self.transforms(seq)
            return seq, label

    def __len__(self):
        # 返回数据集大小
        return len(self.data)


if __name__ == '__main__':
    data_path = r'.\data'
    data_type = 'val'
    mydataset= MyDataset(data_path, data_type)

Dataloader类

import torch.utils.data

train_data_path = r'.\data'
train_dataset = MyDataset(train_data_path, 'train')
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=args.train.batch_size
)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值