首先说明,这是学习过程中的服务于自己的笔记,如有问题,欢迎大家批评指正,不胜感激
def slice_data(data, length=864, number=1000, slice_rate=[0.7, 0.2, 0.1], enc=True, enc_step=28): train_samples = {} #用于存放所有文件的切片结果,每个文件的样本对应字典中一个元素,作为训练集 test_and_valid_samples = {} #用于存放训练集 keys = data.keys() #获取字典的所有键值 #通过遍历健实现遍历所有数据,先采集训练样本,后采集测试样本和验证样本 for key in keys: train_data = [] #临时存储切片得到的训练样本 total_length = len(data[key]) #数据的总长度 end_index = int(total_length * slice_rate[0]) #采集训练样本时,只能采集到前面训练样本所占比例,相当于把整个数据集也分为了训练集、测试集、验证集,至于开始采样的位置由随机数确定 train_num = int(number * slice_rate[0]) #获取训练样本总数量 一定要强制类型转换为整数,不然后面使用==判断就会出现问题 if enc: enc_times = length // enc_step #在一个length里需要增强采集的次数 steps = 0 #记录采样次数,采样到达train_num停止 for j in range(train_num): label = False #标志位,用于确定是