实战通过paddleNlp将富文本正文html中提取企业名称

业务需求

公司一数据库表中有600万条正文,格式是html,需要将html中的是企业名称的全部提取出来。

需要任务自动化,以至于可以24小时全程挂着跑

前提准备

需要带有一张英伟达计算卡的服务器,最好是Tesla的,最低Tesla V100,也可以是GTX系列

GPU需要有CUDA和Cudnn

paddle环境以及事先需要了解什么是paddle https://www.paddlepaddle.org.cn/

逻辑流程

1. 定时任务,每分钟启动一次,判断当前程序是否正在执行,如果是就跳过等待下一轮判断

2. 如果没有任务正在执行,启动任务,在表中记录任务时间节点和任务状态进行拦截

3. 分多线程执行逻辑

4. 所有线程结束后将任务状态还原

代码

python引入包(太多了,都是干啥的不说了,大多数都很常见)

from paddlenlp import Taskflow
from apscheduler.schedulers.blocking import BlockingScheduler
from bs4 import BeautifulSoup
import re
import pymysql
import threading
import datetime
import logging
import torch

主方法(这里job给最大两个进程是一个正在执行的任务,一个是每分钟判断任务进行的定时)

if __name__ == '__main__':
    try:
        # 定时任务,每两分钟执行一次
        scheduler = BlockingScheduler(timezone='Asia/Shanghai', job_defaults={'max_instances': 2})
        scheduler.add_job(task_start, 'interval', minutes=1)
        scheduler.start()
    except Exception as err1:
        logging.error(err1)

常用的查询和更新方法

# 方法:执行查询无参数
def execute_query(db_connection, query):
    cursor = db_connection.cursor()
    cursor.execute(query)
    result = cursor.fetchall()
    cursor.close()
    return result


# 方法:执行带参数的查询
def execute_query_with_params(db_connection, query, params):
    cursor = db_connection.cursor()
    cursor.execute(query, params)
    result = cursor.fetchall()
    cursor.close()
    return result


# 方法2:执行数据处理
def execute_insert_update(db_connection, sql, params):
    cursor = db_connection.cursor()
    cursor.execute(sql, params)
    db_connection.commit()
    cursor.close()


# 当前时间
def getNowDateTimeStr():
    today = datetime.datetime.now()
    return today.strftime(DATETIME_FORMAT)


# 方法:执行插入并获取主键
def execute_insert_and_get_primary_key(db_connection, sql, params):
    cursor = db_connection.cursor()
    cursor.execute(sql, params)
    db_connection.commit()
    # 获取最后插入的主键
    last_inserted_id = cursor.lastrowid
    cursor.close()
    return last_inserted_id

任务第一步 (这里做的自动化是根据数据表中id有序的主键id区间做任务,每500条一次任务)

# 任务开始,第一步
def task_start():
    try:
        db_local = create_db_connection()
        # 查询版本控制
        result1_list = execute_query(db_local, "SELECT id, task_deal_id, task_type FROM t_ggzy_deal_task ORDER BY id desc Limit 1")
        if result1_list is not None and len(result1_list) > 0:
            result1 = result1_list[0]
            result1_id = result1[0]
            # 上次任务的截止deal表的id
            result1_task_id = result1[1]
            # 任务进行状态
            result1_type = result1[2]
            # 0代表程序没在跑
            if result1_type == 0:
                new_id = result1_task_id + 500
                max_id = execute_query(db_local, "SELECT MAX(fdId) FROM t_ggzy_deal")
                if max_id[0][0] < new_id:
                    new_id = max_id[0][0]
                # 插入新任务
                insert_task_params = (new_id, getNowDateTimeStr(), 1)
                logging.info('开启本轮任务>>>>>>>>>>>>>数据处理id区间%s~%s>>>>%s' % (result1_task_id, new_id, getNowDateTimeStr()))
                last_id = execute_insert_and_get_primary_key(db_local, "INSERT INTO `cei_ggzy`.`t_ggzy_deal_task`(`task_deal_id`, `create_time`, `task_type`) VALUES (%s, %s, %s);", insert_task_params)
                # 查询数据区间 上一条记录id和当前id,基本上是3000条
                logging.info('开启查询本轮任务数据总量>>>>>>>>>>>>>%s' % getNowDateTimeStr())
                data_list = execute_query_with_params(db_local, "SELECT fdId, fdUUID FROM t_ggzy_deal WHERE fdId > %s and fdId <= %s", (result1_task_id, new_id))
                logging.info('完成查询本轮任务数据总量>>>>>>>>>>>总数为:%s>>%s' % (len(data_list), getNowDateTimeStr()))
                if len(data_list) > 0:
                    # 开启多线程跑数据
                    threadUtil(data_list, last_id)
        db_local.close()
    except Exception as err2:
        logging.error(err2)

主线程(这里注意torch.cuda.empty_cache()是清除GPU显存的方法)

# 多线程-父线程
def threadUtil(main_list, task_id):
    # 定义线程数量
    num_threads = 5
    # 计算每个线程要处理的子集大小
    subset_size = len(main_list) // num_threads
    # 创建线程对象列表
    threads = []
    # 创建并启动线程
    for i in range(num_threads):
        # 计算子集的起始和结束索引
        start = i * subset_size
        end = start + subset_size if i < num_threads - 1 else len(main_list)
        # 创建线程,并将子集作为参数传递给线程的执行函数
        t = threading.Thread(target=process_subset, args=(main_list[start:end], i))
        threads.append(t)
        t.start()
    # 等待所有线程执行完成
    for t in threads:
        t.join()
    # 打印处理后的列表
    print("所有线程结束了================")
    db_local = create_db_connection()
    # 任务结束后更新任务表状态
    execute_insert_update(db_local, "UPDATE t_ggzy_deal_task SET task_type = 0 WHERE id = %s", task_id)
    db_local.close()
    # 在任务运行之后
    torch.cuda.empty_cache()
    logging.info('完成本轮任务>>>>>>>>>>>>>%s' % getNowDateTimeStr())

子线程(这里是每个线程将html中内容提取企业名称的具体流程,涉及到数据如何清洗等步骤,为了数据准确性)

# 多线程-子线程-执行函数
def process_subset(itemList, threadName):
    try:
        print("线程【%s】一共需要处理的数量==%s" % (threadName, len(itemList)))
        count = len(itemList)
        ner = Taskflow("ner", entity_only=True, batch_size=5, gpu_id=0)
        # 遍历数据,插入到临时表
        for item in itemList:
            deal_id = item[0]
            deal_uuid = item[1]
            # 连接全国公共资源
            db_process = create_db_connection()
            if deal_uuid is not None and deal_uuid != "":
                content_list = execute_query_with_params(db_process, "SELECT fdHtmlContent FROM t_ggzy_deal_content WHERE fdUUID = %s", deal_uuid)
                if len(content_list) > 0:
                    if content_list[0][0] is not None and content_list[0][0] != "":
                        content = content_list[0][0]
                        # 解析HTML
                        soup = BeautifulSoup(content, 'html.parser')
                        # 找到所有的<img>标签并将其移除
                        for img_tag in soup.find_all('img'):
                            img_tag.extract()
                        # 提取所有文本内容
                        all_text = soup.get_text()
                        # 清洗数据格式,以保证NLP提取准确性
                        all_text = re.sub(r'\n+', '!', all_text)
                        all_text = re.sub(r' ', '!', all_text)
                        # 使用正则表达式将文本分成句子或段落,处理长文本
                        sentences = re.split(r'[。!?]', all_text)
                        # 拼接企业名称结果用,用来去重企业名称操作
                        new_data = ''
                        # 记录表记录解析成功失败的布尔值
                        flag = False
                        # 遍历分段后的文本
                        for sentence in sentences:
                            # 空字符串解析会导致NLP报错
                            if sentence != "":
                                # 使用NER Taskflow处理分段文本
                                result = ner(sentence)
                                # 遍历每段结果,结果是个list
                                for nei in result:
                                    neiName = nei[0].replace(" ", "")
                                    neiName = neiName.replace('\xa0', '')
                                    neiName = re.sub(r'\u3000+', '', neiName)
                                    neiType = nei[1]
                                    # 提取富文本中全部企业类型并且去重拼接
                                    if (neiType == '组织机构类_企事业单位' or neiType == '组织机构类_国家机关' or neiType == '组织机构类_体育组织机构' or neiType == '组织机构类_军事组织机构' or
                                        neiType == '组织机构类_医疗卫生机构' or neiType == '组织机构类_教育组织机构') and neiName != "" and neiName not in new_data:
                                        # 如果企业名称字数小于4个字,肯定不是真实企业
                                        if len(neiName) >= 4:
                                            if len(neiName) == 4 and neiName[-2:] == "单位":
                                                continue
                                            print("线程【%s】>>>>%s" % (threadName, neiName))
                                            # 拼接字符串,去重操作
                                            new_data = new_data + neiName + " "
                                            # 将企业名称和uuid绑定插入,过程中有判断是否存在的去重操作
                                            suspected_handle(db_process, deal_uuid, neiName)
                                            flag = True
                        if flag:
                            # 记录状态
                            record_handle(db_process, deal_id, deal_uuid, 1)
                        else:
                            record_handle(db_process, deal_id, deal_uuid, 3)
                    else:
                        # 记录状态
                        record_handle(db_process, deal_id, deal_uuid, 2)
                else:
                    # 记录状态
                    record_handle(db_process, deal_id, deal_uuid, 2)
            else:
                # 记录状态,避免重复插入记录,不做增量
                record_handle(db_process, deal_id, deal_uuid, 0)
            count -= 1
            print("线程%s========剩余%s" % (threadName, count))
            db_process.close()
    except Exception as err3:
        logging.error(err3)

以前调用的部分方法比如记录日志,判断数据更新等等就不发布了,涉及到项目具体内容。

可能存在问题以及解决

服务器GPU报内存溢出:减少paddleNlp数据批次batch的大小,减少启用的线程

paddleNlp未使用到GPU而是用的CPU:看看paddle安装的是不是GPU版本,如果是的话有没有安装GPU需要的驱动,最后检查Taskflow调用包的时候有没有加gpu_id=0

为什么解析结果不准确:html需要进行数据格式清洗,清洗流程最好将空格,换行等替换成逗号,能增加准确性

单线程很准,多线程结果不准:这种就是可能Taskflow加载包的时候是全部线程共用的一个包,最好把加载步骤写在子线程里,每个线程用自己的包对象

怎么检查是否调用了GPU:CPU和GPU执行NLP的时间天差地别,基本上500条正文用CPU可能需要20分钟甚至半个小时,GPU只需要两分钟,看速度就知道了

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值