CSV dataset contents (tab-separated: label, sentence 1, sentence 2):
1 可以 医生有微信吗
1 医生在吗? 好的,谢谢
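The loader below imports a get_bert_encoder helper from a separate module that is not shown in this article. Judging from how it is called, it takes a sentence pair and returns a BERT embedding tensor of shape [1, seq_len, hidden_size]. A minimal sketch of such a helper, assuming the Hugging Face transformers library, the bert-base-chinese checkpoint, and a fixed max_length of 32 (all of these are assumptions, not details of the original module):

# Sketch only: assumes transformers and 'bert-base-chinese';
# the real get_bert_encoder module may differ.
import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
bert_model = BertModel.from_pretrained('bert-base-chinese')
bert_model.eval()

def get_bert_encoder(text_1, text_2):
    # Encode the two sentences as one BERT sentence pair, padded to a fixed
    # length so that batches can later be concatenated with torch.cat.
    inputs = tokenizer(text_1, text_2, return_tensors='pt',
                       padding='max_length', truncation=True, max_length=32)
    with torch.no_grad():
        outputs = bert_model(**inputs)
    # Shape: [1, 32, 768] -- one tensor per sample.
    return outputs.last_hidden_state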
from collections import Counter
from functools import reduce
import torch
import pandas as pd
from sklearn.utils import shuffle
from get_bert_encoder import get_bert_encoder
def data_loader(data_path, batch_size, split=0.2):
    # Read the tab-separated file: column 0 is the label, columns 1 and 2 are the sentence pair.
    data = pd.read_csv(data_path, header=None, sep='\t')
    # Print the label distribution.
    print(dict(Counter(data[0].values)))
    # Shuffle the dataset, then split it into validation and training parts.
    data = shuffle(data)
    split_point = int(len(data) * split)
    valid_data = data[:split_point]
    train_data = data[split_point:]
    if len(valid_data) < batch_size:
        raise ValueError("validation set is smaller than batch_size")

    def _loader_generator(dataset):
        for batch in range(0, len(dataset), batch_size):
            # batch_encoder collects one BERT tensor per sample,
            # each of shape [1, seq_len, hidden_size].
            batch_encoder = []
            # batch_labels collects one-element label lists, e.g. [[0], [1], [1], ...].
            batch_labels = []
            # Convert the DataFrame slice to a list of rows;
            # each item looks like ['1', '可以', '医生有微信吗'].
            for item in dataset[batch:batch + batch_size].values.tolist():
                en = get_bert_encoder(item[1], item[2])
                batch_encoder.append(en)
                batch_labels.append([item[0]])
            # Concatenate the per-sample tensors along dim 0
            # into a single [batch_size, seq_len, hidden_size] tensor.
            encoder = reduce(lambda x, y: torch.cat((x, y), dim=0), batch_encoder)
            # Flatten the label lists into one tensor of shape [batch_size].
            labels = torch.tensor(reduce(lambda x, y: x + y, batch_labels))
            yield (encoder, labels)

    return _loader_generator(train_data), _loader_generator(valid_data), len(train_data), len(valid_data)
batch_size = 32
data_path = '/root/mmm/train_data.csv'
train_data_labels, valid_data_labels, train_data_len, valid_data_len = data_loader(data_path, batch_size)
print(next(valid_data_labels))
One batch of data (output abbreviated; the encoder tensor has shape [batch_size, seq_len, hidden_size] and the label tensor has shape [batch_size]):
(tensor([[[-1.0072e+00,  1.1102e+00,  5.6757e-01,  ...,  9.0415e-01,  5.9650e-01,  7.8008e-01],
          ...,
          [ 5.5248e-02,  5.0485e-01, -6.2609e-01,  ...,  2.9628e-01,  9.1855e-01, -6.5408e-03]],
         ...,
         [[-6.6924e-01,  1.0480e+00,  6.3848e-01,  ...,  5.3326e-01,  5.5357e-01,  1.0924e-01],
          ...,
          [ 7.3526e-01,  2.8660e-01,  4.5617e-01,  ...,  9.6911e-02,  3.9665e-01,  2.8705e-01]]]),
 tensor([0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0,
         0, 1, 1, 0, 1, 1, 0, 0]))
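The generators returned by data_loader can be consumed directly in a training loop. The sketch below only shows how the yielded (encoder, labels) pairs would be used; the PairClassifier model, the hidden size of 768, and the optimizer settings are placeholder assumptions, not part of the original code:

import torch
import torch.nn as nn

class PairClassifier(nn.Module):
    # Placeholder classifier: mean-pool the token embeddings, then apply a linear layer.
    def __init__(self, hidden_size=768, num_classes=2):
        super().__init__()
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: [batch_size, seq_len, hidden_size]
        return self.fc(x.mean(dim=1))

model = PairClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

train_loader, valid_loader, train_len, valid_len = data_loader(data_path, batch_size)
for encoder, labels in train_loader:
    optimizer.zero_grad()
    logits = model(encoder)           # [batch_size, 2]
    loss = criterion(logits, labels)  # labels has shape [batch_size]
    loss.backward()
    optimizer.step()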