Introduction
A collection of utility functions I use when writing deep learning programs.
Datasets
Splitting a dataset
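Approach to splitting a dataset:
- If the dataset is small, shuffle it directly:
  ```python
  import random

  random.shuffle(data)  # shuffles the list `data` in place
  ```
- If the dataset is large, shuffle an array of indices instead of the data itself, and split the dataset according to those indices.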
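A minimal sketch of the large-dataset approach, assuming the samples live in an indexable container such as a list. `shuffled_indices` is a hypothetical helper name, not part of the original code. Only the list of indices is permuted, so the samples themselves are never copied or reordered, and a seeded `random.Random` instance makes the split reproducible without touching the global random state.

```python
import random


def shuffled_indices(n, seed=None):
    """Hypothetical helper: return a random permutation of 0..n-1."""
    rng = random.Random(seed)  # local RNG; the global `random` state is untouched
    idx = list(range(n))
    rng.shuffle(idx)           # in-place shuffle of the indices only
    return idx


# Example: the data list is left as-is; splits are defined by index slices
data = [f"sample_{i}" for i in range(10)]
idx = shuffled_indices(len(data), seed=0)
train_idx, rest_idx = idx[:8], idx[8:]
train = [data[i] for i in train_idx]  # fetch samples only when they are needed
```

With a real dataset, the index lists can be stored (or saved to disk) and used to fetch samples lazily, which is where this approach pays off over shuffling the data directly.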
get_dataset_split_num
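Computes the sizes of the training, validation and test sets. You do not pass a training-set size; pass only the validation and test sizes, either both as fractions of the dataset or both as absolute counts. The training set receives the remaining samples.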
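```python
def get_dataset_split_num(n, valid=0, test=0):
    """
    n: number of samples in the dataset
    valid, test: both fractions in [0, 1), or both absolute sample counts
    Returns (train_num, valid_num, test_num); the training set gets the remainder.
    """
    if valid < 1 and test < 1:
        # Both sizes given as fractions of n; int() truncates toward zero
        assert valid + test > 0, "valid and test cannot both be 0"
        valid_num = int(n * valid)
        test_num = int(n * test)
    else:
        # Both sizes given as absolute counts (0 is allowed for either one)
        assert float(valid).is_integer() and float(test).is_integer(), \
            "do not mix a fraction with an absolute count"
        valid_num = int(valid)
        test_num = int(test)
    train_num = n - valid_num - test_num
    assert train_num >= 0, "valid + test exceeds the dataset size"
    return train_num, valid_num, test_num
```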
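Usage:
```python
# Validation and test sizes given as fractions of n
train_num, valid_num, test_num = get_dataset_split_num(100, valid=0.2, test=0.31)
# Validation and test sizes given as absolute counts
train_num, valid_num, test_num = get_dataset_split_num(100, valid=20, test=31)
```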
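In the fraction case, the sizes computed by `get_dataset_split_num` can be written as follows (`int()` truncates toward zero, which equals the floor for non-negative values):

$$
\text{valid\_num} = \lfloor n \cdot \text{valid} \rfloor, \qquad
\text{test\_num} = \lfloor n \cdot \text{test} \rfloor, \qquad
\text{train\_num} = n - \text{valid\_num} - \text{test\_num}
$$

For example, with $n = 7$ and $\text{valid} = \text{test} = 0.2$: $\lfloor 1.4 \rfloor = 1$ for both, so $\text{train\_num} = 7 - 1 - 1 = 5$. Because of the truncation, any leftover samples from rounding always end up in the training set.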
cut_datasets
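Splits an index array that has already been shuffled into training, validation and test parts, using `get_dataset_split_num` above to compute the sizes.
```python
def cut_datasets(arr, valid=0, test=0):
    """
    arr: array of indices (shuffle it before calling)
    valid, test: fractions or absolute counts, as in get_dataset_split_num
    Returns the training, validation and test slices of arr.
    """
    train_num, valid_num, _ = get_dataset_split_num(len(arr), valid, test)
    a1 = arr[:train_num]                       # training indices
    a2 = arr[train_num:train_num + valid_num]  # validation indices
    a3 = arr[train_num + valid_num:]           # test indices (everything left)
    return a1, a2, a3
```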
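A sketch of how the pieces might fit together end to end: shuffle indices, compute the split sizes, cut the index array, then gather the samples. `split_dataset` is a hypothetical name, and the size logic is inlined (fractions only, truncated with `int()` as above) so the block runs on its own without the functions defined earlier.

```python
import random


def split_dataset(data, valid=0.1, test=0.1, seed=None):
    """Hypothetical helper: shuffle indices, then cut them into train/valid/test.

    valid, test: fractions of len(data) (absolute counts are not handled here).
    """
    n = len(data)
    idx = list(range(n))
    random.Random(seed).shuffle(idx)       # shuffle the indices, not the data
    valid_num = int(n * valid)
    test_num = int(n * test)
    train_num = n - valid_num - test_num   # training set takes the remainder
    parts = (
        idx[:train_num],
        idx[train_num:train_num + valid_num],
        idx[train_num + valid_num:],
    )
    # Gather the samples for each split by index
    return tuple([data[i] for i in part] for part in parts)


train_set, valid_set, test_set = split_dataset(list(range(100)), valid=0.2, test=0.1, seed=42)
```

Passing a fixed `seed` gives the same split on every run, which helps when comparing models trained at different times.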