【NLP挑战赛】1、基于 sklearn 将 train 数据拆分为 train 与 dev

如题,训练某些模型时只提供了 train 数据,需要我们手动将其拆分为 train 和 dev 两部分,以便对模型进行验证。这里介绍一种基于 sklearn `train_test_split` 的简单方法。

import logging
import os

import pandas as pd
from sklearn.model_selection import train_test_split

# Root-logger setup for this script: INFO and above, with a timestamp and level prefix.
_LOG_FORMAT = '%(asctime)-15s %(levelname)s: %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)


def _write_split(path, texts, labels):
    """Write one split to ``path`` as ``text<TAB>label`` lines in UTF-8.

    The parent directory is created first if needed, so writing into a
    not-yet-existing ``train_split_data/`` folder does not fail.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as fout:
        # One generator of complete lines -> a single buffered writelines call.
        fout.writelines(f"{text}\t{label}\n" for text, label in zip(texts, labels))
    logging.info("save  %s done !", path)


def all_data2fold(data_file, train_dir, test_dir, *, test_size=0.2,
                  random_state=None, stratify=False):
    """Split a labelled sentence-pair TSV file into train and dev files.

    Reads ``data_file`` (tab-separated, header row with at least the columns
    ``text_a``, ``text_b`` and ``label``), joins the two text columns with a
    tab, splits the rows with sklearn's ``train_test_split`` and writes each
    split as ``text_a<TAB>text_b<TAB>label`` lines.

    Args:
        data_file: path of the input TSV file.
        train_dir: output file path for the training split (a file, despite the name).
        test_dir: output file path for the dev split (a file, despite the name).
        test_size: fraction (float) or absolute number (int) of rows for the
            dev split; defaults to 0.2 as before.
        random_state: seed for a reproducible split; ``None`` (the default)
            gives a different split on every run.
        stratify: if True, keep the label proportions equal in both splits.
    """
    # Read the text columns as strings and keep empty cells as "" rather than
    # NaN. Otherwise a blank question (or one whose text is "NA"/"null") turns
    # into a float NaN, and an all-digit column is parsed as int. Either way
    # the string concatenation below raises TypeError.
    # NOTE(review): default CSV quoting is still active; if the raw text may
    # start with '"', consider quoting=csv.QUOTE_NONE.
    df = pd.read_csv(data_file, sep='\t', encoding='UTF-8',
                     dtype={"text_a": str, "text_b": str},
                     keep_default_na=False)
    logging.info("columns: %s", list(df.columns))
    logging.info("head:\n%s", df.head())

    x = df["text_a"] + "\t" + df["text_b"]
    y = df["label"]
    # Series.value_counts replaces the deprecated top-level pd.value_counts.
    logging.info("总的每个label的数量:\n%s", y.value_counts())

    x_train, x_test, y_train, y_test = train_test_split(
        x, y,
        test_size=test_size,
        random_state=random_state,
        stratify=y if stratify else None,
    )

    logging.info("train: %d rows, dev: %d rows", len(x_train), len(x_test))
    logging.info("train的每个label的数量:\n%s", y_train.value_counts())
    logging.info("dev的每个label的数量:\n%s", y_test.value_counts())

    _write_split(train_dir, x_train, y_train)
    _write_split(test_dir, x_test, y_test)


if __name__ == "__main__":
    # All inputs and outputs live under one dataset root; the split files go
    # into its train_split_data/ subfolder.
    data_root = "data/chinese_question_sim"
    split_root = data_root + "/train_split_data"
    all_data2fold(
        data_file=data_root + "/train.csv",
        train_dir=split_root + "/train_data.csv",
        test_dir=split_root + "/dev_data.csv",
    )
