数据生成器迭代问题

数据生成器迭代问题


def data_loader(data_path, batch_size, split=0.2):
    """
    description: 从持久化文件中加载数据, 并划分训练集和验证集及其批次大小
    :param data_path: 训练数据的持久化路径
    :param batch_size: 训练和验证数据集的批次大小
    :param split: 训练集与验证的划分比例
    :return: 训练数据生成器, 验证数据生成器, 训练数据数量, 验证数据数量
    """
    data = []
    # 使用pd进行csv数据的读取
    # data = pd.read_csv(data_path, header=None, sep="\t")
    # data = pd.read_csv(data_path, header=None, sep="\t")

    # data_path = './dev.tsv'
    data_1 = pd.read_csv(data_path, header=None, sep="\t",error_bad_lines=False)
    # data_1 = pd.read_csv(data_path, header=None, sep="  ")


    # for ii in range(len(data_1[0])):
    #
    #     data_temp = []
    #     if len(data_1[3][ii]) < 80:
    #         if len(str(data_1[4][ii])) < 80:
    #             data_temp.append(data_1[0][ii])
    #             data_temp.append(data_1[3][ii])
    #             data_temp.append(data_1[4][ii])
    #
    #             data.append(data_temp)


    # data_path = './sts-dev.tsv'
    data_path1 = './sts-train.tsv'
    data_1 = pd.read_csv(data_path1, header=None, sep="\t",error_bad_lines=False)

    # data = []
    for ii in range(len(data_1[4])):

        data_temp = []
        if len(data_1[5][ii]) < 80:
            if len(str(data_1[6][ii])) < 80:
                if data_1[4][ii] < 2.5:
                    score1 = 0
                else:
                    score1 = 1
                # data_temp.append(data_1[4][ii])
                data_temp.append(score1)
                data_temp.append(data_1[5][ii])
                data_temp.append(data_1[6][ii])

                data.append(data_temp)

    # for i in data:
    #     print(i)
    #
    # data = pd.DataFrame(data, columns=[0, 1, 2])
    # print(dict(Counter(data[0].values)))


    # data=pd.array(data)
    # data=pd.DataFrame(data)
    data = pd.DataFrame(data, columns=[0, 1, 2])
    data.dropna(axis='index', inplace=True)
    # 打印整体数据集上的正负样本数量
    print("数据集的正负样本数量:")
    print(dict(Counter(data[0].values)))

    # 打乱数据集的顺序
    data = shuffle(data).reset_index(drop=True)
    # 划分训练集和验证集
    split_point = int(len(data)*split)
    valid_data = data[:split_point]
    train_data = data[split_point:]

    # 验证数据集中的数据总数至少能够满足一个批次
    if len(valid_data) < batch_size:
        raise("Batch size or split not match!")


    def _loader_generator(data):
        """
        description: 获得训练集/验证集的每个批次数据的生成器
        :param data: 训练数据或验证数据
        :return: 一个批次的训练数据或验证数据的生成器
        """
        # 以每个批次的间隔遍历数据集
        for batch in range(0, len(data), batch_size):
            # 预定于batch数据的张量列表
            batch_encoded = []
            batch_labels = []
            # 将一个bitch_size大小的数据转换成列表形式,[[label, text_1, text_2]]
            # 并进行逐条遍历
            for item in data[batch: batch+batch_size].values.tolist():
                # 每条数据中都包含两句话, 使用bert中文模型进行编码
                encoded = get_bert_encode(item[1], item[2])
                # encoded = get_bert_encode(item[3], item[4])
                # encoded = get_bert_encode(item[3], item[4])
                # 将编码后的每条数据装进预先定义好的列表中
                batch_encoded.append(encoded)
                # 同样将对应的该batch的标签装进labels列表中
                batch_labels.append([item[0]])
            # 使用reduce高阶函数将列表中的数据转换成模型需要的张量形式
            # encoded的形状是(batch_size, 2*max_len, embedding_size),dim=0表示按照batchsize维度拼接
            encoded = reduce(lambda x, y : torch.cat((x, y), dim=0), batch_encoded)
            # labels = torch.tensor(reduce(lambda x, y : x + y, batch_labels))
            labels = torch.tensor(reduce(lambda x, y : x + y, batch_labels))
            #x+y,拼接列表
            # 以生成器的方式返回数据和标签
            yield (encoded, labels)

    # 对训练集和验证集分别使用_loader_generator函数, 返回对应的生成器
    # 最后还要返回训练集和验证集的样本数量
    return _loader_generator(train_data), _loader_generator(valid_data), len(train_data), len(valid_data)



评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值