大规模未标记的文本数据分类处理baseline

一、问题介绍

这里是华为的一个文本分类比赛,数据量大,而且有很多文章并没有标记类别。基础数据集包含两部分:训练集和测试集。其中训练集给定了该样本的文章质量的相关标签,测试集用来测试模型的标签预测准确率。在这里插入图片描述
该文本分类的难点主要有两个,一、文章的长度比较长,属于长文本分类,而Bert的最大输入只有512.二、训练集中有大量的未标记数据,而且还包含了“类别”为其他的文本,但是没有标记出来。所以对测试集分类的时候,也要考虑文章类别为“其他”的情况。
以下给出训练集中标签和文本数量的数据情况,’ ‘表示未标记数据,对应的文章可能有类型,也可能类型为“其他”。
训练集有576454条文本数据,只有76454条有标签。
{’ ': 500000, ‘人物专栏’: 7242, ‘情感解读’: 7183, ‘科普知识文’: 6337, ‘攻略文’: 5517, ‘物品评测’: 4381, ‘治愈系文章’: 3868, ‘推荐文’: 1194, ‘深度事件’: 16670, ‘作品分析’: 14094, ‘行业解读’: 9968}

二、Baseline解决思路

(1)对数据预处理,对原始训练集、测试集进行简单清洗,处理训练集得到带标记样本集(P)和未知标记样本集(U)。
(2)使用P训练Bert分类器(10分类)
(3)使用已训练的Bert分类器预测U,并输出可靠负样本集(RN),文章类型为“其他”。
(4)使用P和RN训练二分类器
(5)使用二分类器和Bert分类器联合预测测试集上样本的类别,并保存文章id和预测的标签至文件

简单提一下这里的难点,就是获取可靠负样本集(RN),Baseline采用了随机森林的机器学习方法,也可以采用朴素贝叶斯的方法。这一步获取的负样本影响到二分类的训练,二分类训练器主要用来判别文章类型为“其他”的文章。

三、代码

1、Config.py

给出数据、模型等文件目录

# The path to the directory which stores all the datasets.
# If equals "", the path will set to the path to the directory which stores the datasets downloaded from Digix website.
BASE_DATASET_PATH = "./data"
# The path to the directory which stores all the trained model files.
# If equals "", the path will set to the path to the sub directory "model" at current directory
BASE_MODEL_PATH = "./model"
# The path to the pretrained BERT ENCODER, such as "/data/bert_base_chinese", you can download it at https://huggingface.co/bert-base-chinese/tree/main
PRETRAINED_BERT_ENCODER_PATH = "./预训练模型"
# The path to save the structured result(.csv) of evaluating test files.
# If equals "", the path will set to the file "submission.csv" at the current directory
SUMMARY_OUTPUT_PATH = "./summary"

2、main.py

执行训练并保存预测的结果,即处理的流程

from Preprocess import preprocess
from Build_PU_data import build_pu_data
from Train_Bert import train_bert
from Train_PU_model import train_pu_model
from Joint_Predictor import joint_predictor

if __name__ == "__main__":
    # 数据集预处理:对原始训练集,测试集进行简单的清洗,从训练集中输出 带标记样本集(P)和 未知样本集(U)
    preprocess()
    # 使用 P 训练 Bert分类器(10分类)
    train_bert()
    # 使用已训练的 Bert分类器预测 U,并输出 可靠负样本集(RN)
    build_pu_data()
    # 使用 P 和 RN 训练 二分类器
    train_pu_model()
    # 使用 Bert分类器 和 二分类器 联合预测 测试集上样本的类别,并格式化输出结果至文件
    joint_predictor()

3、Preprocess.py

数据处理,清洗训练数据和测试数据,并取出训练集中带标签的数据P和不带标签的数据U。数据格式如下:
在这里插入图片描述
下面给出处理的代码

import json
import os
import pandas as pd
import re
from tqdm import tqdm
from bs4 import BeautifulSoup
import Config

if Config.BASE_DATASET_PATH == "":
    curdir = os.path.dirname(os.path.abspath(__file__))
    dataset_path = os.path.join(curdir, "dataset")
    if not os.path.exists(dataset_path):
        os.mkdir(dataset_path)
else:
    dataset_path = Config.BASE_DATASET_PATH

RAW_TRAIN_FILE_PATH = os.path.join(dataset_path, "doc_quality_data_train_1000.json")
RAW_TEST_FILE_PATH = os.path.join(dataset_path, "doc_quality_data_test_1000.json")
PREPROCESSED_TRAIN_FILE_PATH = os.path.join(dataset_path, "preprocessed_train.json")
PREPROCESSED_TEST_FILE_PATH = os.path.join(dataset_path, "preprocessed_test.json")
POSITIVE_TRAIN_FILE_PATH = os.path.join(dataset_path, "postive_train.json")
POSITIVE_TRAIN_INFO_PATH = os.path.join(dataset_path, "positive_info.json")
UNLABELED_TRAIN_FILE_PATH = os.path.join(dataset_path, "unlabeled_train.json")
# 优质类别索引列表
INDEX = ['人物专栏', '作品分析', '情感解读', '推荐文', '攻略文', '治愈系文章', '深度事件', '物品评测', '科普知识文', '行业解读']

# 获取数据集的标签集及其大小
def get_label_set_and_sample_num(config_path, sample_num=False):
    with open(config_path, "r", encoding="UTF-8") as input_file:
        json_data = json.loads(input_file.readline())
        if sample_num:
            return json_data["label_list"], json_data["total_num"]
        else:
            return json_data["label_list"]


# 生成数据集对应的标签集以及样本总数
def build_label_set_and_sample_num(input_path, output_path):
    label_set = set()
    sample_num = 0
    
    with open(input_path, 'r', encoding="utf-8") as input_file:
        for line in tqdm(input_file):
            json_data = json.loads(line)
            label_set.add(json_data["label"])
            sample_num += 1
            
    with open(output_path, "w", encoding="UTF-8") as output_file:
        record = {
   "label_list": sorted(list(label_set)), "total_num": sample_num}
        json.dump(record, output_file, ensure_ascii=False)

        return record["label_list"], record["total_num"]


def get_sentences_list(raw_text: str):
    #BeautifulSoup对象,参数 文档字符串,html解析器,文档编码
    return [s for s in BeautifulSoup(raw_text, 'html.parser')._all_strings()]


def check_length(length_list):
    #sum对列表的元素求和
    sum_length = sum(length_list)
    if sum_length < 510:
        return sum_length
    return 510


# 去除空白字符, 从数据集遍历代码中移至此处
def remove_symbol(string: str):
    return string.replace('\t', '').replace('\n', '').replace('\r', '')

#这一步主要是解决一部分标题在文本中也出现了的情况,因为训练Bert时是取标题加上文本开头的部分。不超过512
def check_duplicate_title(input_path, output_path):
    duplicate = 0
    no_html = 0
    no_duplicate = 0
    print("Processing File: ", input_path)
    with open(input_path, "r", encoding='utf-8') as file, open(output_path, "w", encoding="utf-8") as outfile:
        for line in tqdm(file):
            json_data = json.loads(line)
            title = json_data["title"]
            body = get_sentences_list(json_data["body"])
            title_length = len(title)

            # 正文中不含HTML标签
            if len(body) == 1:
                no_html += 1
                tmp_body = body[0]
                # 注意,这边re.sub的pattern使用了re.escape()
                # 是为了转译title中存在的会被re视为元字符的字符(例如"?"","*")
                # 事实上相当于"\".join(title)[将所有字符转译为普通字符]
                new_body = re.sub("(原标题:)?" + re.escape(title), "", tmp_body)
                new_body_length = len(new_body)

                if new_body_length == len(tmp_body):
                    no_duplicate += 1
                else:
                    duplicate += 1

            # 正文中包含HTML标签
            else:
                i = 0
                # 检查 标题是否出现在前两个元素中 (有可能存在标签<p class=\"ori_titlesource\">,会有"原标题: title"的情况出现)
                for sentence in body[:2]:
                    if title in sentence:
                        i += 1

                new_body = "".join(body[i:])

                if i > 0:
                    duplicate += 1
                else:
                    no_duplicate += 1

            rm_whites_body = remove_symbol(new_body)
            rm_whites_title = remove_symbol(title)

            json_data["body"] = rm_whites_body
            json_data["title"] = rm_whites_title
            json_data["length"] = check_length([len(rm_whites_body), len(rm_whites_title)])
            json.dump(json_data, outfile, ensure_ascii=False)
            outfile.write("\n")

    print("duplicate: {}\t no_html: {}, no_duplicate: {}\n".format(duplicate, no_html, no_duplicate))


def index_data_pd(index, input_path, output_path1, output_path2):
    print(input_path)
    df_data = pd.read_json(input_path, orient="records", lines=True)
    # 处理已标注数据
    df_data_labeled = df_data[df_data["doctype"
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值