Python 构造数据批次迭代器(batch 生成器)

数据集文件内容示例(制表符分隔,三列:标签、句子1、句子2)

1	可以	 医生有微信吗
1	医生在吗?	好的,谢谢

from collections import Counter
from functools import reduce
import torch
import pandas as pd
from sklearn.utils import shuffle
from get_bert_encoder import get_bert_encoder


def data_loader(data_path, batch_size, split=0.2):
    """Load a tab-separated dataset and build train/validation batch generators.

    The file is expected to have no header row and three tab-separated
    columns per line: label, sentence 1, sentence 2
    (e.g. ``1<TAB>可以<TAB>医生有微信吗``).

    Args:
        data_path: path to the tab-separated data file.
        batch_size: number of samples per yielded batch.
        split: fraction of the shuffled data used for validation.

    Returns:
        A 4-tuple ``(train_gen, valid_gen, train_len, valid_len)``.
        Each generator yields ``(encoder, labels)`` where
        ``encoder`` is the batch of BERT encodings concatenated along dim 0
        and ``labels`` is a 1-D tensor of the batch labels.

    Raises:
        ValueError: if the validation split is smaller than ``batch_size``.
    """
    data = pd.read_csv(data_path, header=None, sep='\t')

    # Show the label distribution, e.g. {1: 1000, 0: 900}.
    print(dict(Counter(data[0].values)))

    # Shuffle before splitting so train/valid share the same distribution.
    # (pandas-native replacement for sklearn.utils.shuffle)
    data = data.sample(frac=1)

    split_point = int(len(data) * split)
    valid_data = data[:split_point]
    train_data = data[split_point:]

    if len(valid_data) < batch_size:
        # Fix: the original `raise ("...")` raised a plain str, which is
        # itself a TypeError in Python 3; raise a real exception instead.
        raise ValueError("batch size not match")

    def _loader_generator(subset):
        """Yield (encoder, labels) batches from the given DataFrame."""
        for start in range(0, len(subset), batch_size):
            batch_encoder = []
            batch_labels = []

            # Each item looks like [1, '可以', '医生有微信吗'].
            for item in subset[start:start + batch_size].values.tolist():
                batch_encoder.append(get_bert_encoder(item[1], item[2]))
                batch_labels.append(item[0])

            # Single torch.cat call replaces the original
            # reduce(lambda x, y: torch.cat(...)), which created N-1
            # intermediate tensors per batch.
            encoder = torch.cat(batch_encoder, dim=0)
            # Labels were collected as scalars, so no reduce-flatten needed.
            labels = torch.tensor(batch_labels)

            yield (encoder, labels)

    return (_loader_generator(train_data), _loader_generator(valid_data),
            len(train_data), len(valid_data))


if __name__ == "__main__":
    # Demo entry point: build the loaders from the training file and show
    # one validation batch. Guarded so importing this module no longer
    # triggers file I/O and printing as a side effect.
    batch_size = 32
    data_path = '/root/mmm/train_data.csv'
    train_data_labels, valid_data_labels, train_data_len, valid_data_len = data_loader(data_path, batch_size)
    print(next(valid_data_labels))

一个 batch_size 大小的批次数据输出示例:


(tensor([[[-1.0072e+00,  1.1102e+00,  5.6757e-01,  ...,  9.0415e-01,
           5.9650e-01,  7.8008e-01],
         [ 1.1987e-01,  6.7578e-02,  1.3482e-01,  ..., -1.7363e-01,
          -9.1816e-01,  1.4676e-01],
         [ 6.4202e-01, -6.2847e-01, -1.2118e+00,  ...,  1.4147e+00,
          -7.8515e-01,  7.8756e-01],
         ...,
         [ 2.0897e-01,  6.6435e-01, -1.2239e+00,  ...,  1.7984e-01,
           9.8415e-01, -1.3609e-01],
         [ 5.9390e-01,  8.5917e-01, -5.5631e-01,  ...,  7.0059e-01,
           8.4440e-01,  1.8956e-01],
         [ 5.5248e-02,  5.0485e-01, -6.2609e-01,  ...,  2.9628e-01,
           9.1855e-01, -6.5408e-03]],

        [[-7.0508e-01,  1.5962e+00,  1.2962e-01,  ...,  1.4066e+00,
           3.5016e-01,  2.8185e-01],
         [ 2.2448e-01, -4.4541e-01,  1.1292e-01,  ...,  3.8654e-01,
          -6.4893e-01,  3.6159e-01],
         [ 2.5746e-01, -3.0261e-01, -1.6573e-01,  ...,  1.1788e+00,
          -7.0541e-02,  6.6765e-01],
         ...,
         [ 1.4351e-01, -6.8308e-01,  4.4039e-01,  ...,  9.0505e-01,
           5.3005e-01, -5.3124e-01],
         [-8.0499e-01,  6.5483e-02,  2.9044e-01,  ...,  8.9630e-01,
           5.1957e-01, -3.7248e-01],
         [ 3.3567e-01,  2.3897e-01,  2.0454e-01,  ..., -3.4811e-01,
           5.5508e-01, -6.2328e-01]],

        [[ 9.0099e-01,  4.3394e-01, -3.2617e-01,  ...,  1.0812e-01,
          -1.6899e-01, -3.9596e-01],
         [ 1.4387e-01,  9.1589e-02, -4.7001e-01,  ..., -2.8069e-01,
          -3.7368e-01, -1.6904e-01],
         [ 1.5482e+00,  2.2293e-01, -8.1567e-01,  ...,  1.9300e+00,
           3.6109e-02,  6.4605e-01],
         ...,
         [ 1.2336e+00, -4.2268e-01,  2.0106e-01,  ...,  5.9081e-01,
           1.3714e+00, -6.4655e-01],
         [ 8.4713e-01, -7.3470e-01,  4.8829e-01,  ...,  5.1511e-01,
           7.1544e-01,  2.7500e-01],
         [ 1.4288e+00, -8.7154e-01, -1.9726e-01,  ...,  1.8973e-01,
           3.1826e-01,  1.6991e-01]],

        ...,

        [[ 4.3038e-02,  1.3970e-01,  8.0579e-01,  ...,  4.8711e-01,
           5.4756e-01,  6.6365e-02],
         [-7.4814e-01,  1.0863e+00,  5.4144e-01,  ...,  6.9972e-02,
           5.5441e-02, -5.3881e-01],
         [ 2.9469e-01,  1.5886e-01,  3.0927e-01,  ...,  3.8557e-01,
          -3.1098e-01, -4.0174e-02],
         ...,
         [-5.4478e-01,  2.8056e-01, -4.9826e-01,  ...,  5.6205e-01,
           1.0145e+00, -3.5664e-01],
         [ 8.9176e-01, -7.0300e-01,  1.9767e-01,  ...,  1.1309e+00,
           4.9913e-01, -2.7993e-01],
         [-2.4874e-02, -1.7586e-01,  8.5407e-01,  ...,  4.5933e-01,
           1.2455e+00, -4.3745e-01]],

        [[-5.6448e-01,  1.2819e+00,  9.1512e-01,  ...,  4.3023e-01,
           1.3488e+00, -3.2839e-01],
         [ 5.3092e-01,  3.9522e-01,  2.0146e-01,  ..., -4.3103e-01,
          -1.3318e+00,  5.1111e-01],
         [ 3.1350e-01, -1.4403e-01, -4.2147e-01,  ...,  4.6097e-01,
          -4.1791e-01,  5.1731e-01],
         ...,
         [ 1.1909e+00,  4.3437e-02, -8.8326e-02,  ...,  6.6578e-01,
           2.9038e-01,  2.8392e-01],
         [ 9.2986e-01,  2.6484e-02, -7.3529e-01,  ...,  1.6857e+00,
           1.2189e+00,  1.8110e-01],
         [ 2.7756e-02,  5.7020e-01,  1.1896e-01,  ...,  6.3963e-01,
           2.6242e-02, -4.8905e-01]],

        [[-6.6924e-01,  1.0480e+00,  6.3848e-01,  ...,  5.3326e-01,
           5.5357e-01,  1.0924e-01],
         [-1.8421e-01,  1.4540e-03, -8.1378e-03,  ..., -1.6662e-01,
          -7.5137e-01, -1.4550e-01],
         [ 1.5616e+00, -2.4044e-01, -5.0204e-01,  ...,  1.6142e+00,
          -1.5740e-01,  4.8887e-01],
         ...,
         [ 8.0294e-01,  2.6990e-02,  2.1277e-01,  ...,  7.0323e-01,
           8.1737e-01,  8.5575e-03],
         [ 4.8750e-01,  7.2647e-01,  9.3661e-01,  ...,  1.7896e-01,
           8.2719e-01, -5.8381e-02],
         [ 7.3526e-01,  2.8660e-01,  4.5617e-01,  ...,  9.6911e-02,
           3.9665e-01,  2.8705e-01]]]), tensor([0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1
, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0]))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值