使用BERT对搜狐新闻文本进行分类(一)

使用BERT对搜狐新闻文本进行分类(一)

本人小白,刚刚接触NLP领域,通过写博客方式来记录写一下学习的过程。如果有文章中有不对或可以改善指出,希望大家可以帮我指出。谢谢大家。

处理搜狐新闻数据文本

数据集来源:http://www.sogou.com/labs/resource/cs.php,下载的是完整版648M。

1.这里直接使用解压后的文件“news_sohusite_xml.dat”(推荐先使用小的文本进行试验处理)
'''
由于一些转码问题,我们先通过二进制方式读取,转换为gbk格式
'''
SOURCE_PATH = 'data/news_sohusite_xml.smarty.dat'
with open(SOURCE_PATH,'rb') as f:
    source_data = f.read().decode('gbk',"ignore")
2.接下来需要对格式进行转换。

可以看到一般来说新闻的类别隐藏在url域中,如gongyi对应公益。对应新闻的题目,content对应新闻的正文。

这里使用python的re模块来对内容进行提取,在这里我将类别,标题,正文都提取取出来,在后面的训练时可以选择是否对标题内容进行训练。

数据格式为

<doc>

<url>页面URL</url>

<docno>页面ID</docno>

<contenttitle>页面标题</contenttitle>

<content>页面内容</content>

</doc>
#提取新闻的内容
def extract_class_extract(doc):
    #先对<url>进行切割 分割为 <url>与其后面的内容,选取其后面的内容,再对</url>进行切割,选取</url>前面的内容
    url = doc.split('<url>')[1].split('</url>')[0]
    # contexttitle同理
    contenttitle = doc.split('<contenttitle>')[1].split('</contenttitle>')[0]
    #context同理
    content = doc.split('<content>')[1].split('</content>')[0]
    #之后从url中提取类别,使用正则表达式,这里得到的是一个列表但其实只有一项,我们直接提取出来
    category = re.findall(r"http://(.*?).sohu.com/", url)
    category = category[0]
    #这样提取的信息可能存在子领域(如try.women,apple.it等,但我们目前是为了分类大领域(只要women或者it),所以需要进一步提取
    #以'.'分割后的最后一个为大领域
    category = category.split('.')[-1]
    return category,contenttitle,content
3.将提取的新闻内容分别放入不同的文件中保存
def write_category_file(category,content):
    dir_name = 'data/category/'
    if not os.path.exists(dir_name):
        os.mkdir(dir_name)
    file_name = category+'.txt'
    path = os.path.join(dir_name, file_name)
    with open(path, 'a', encoding='gbk') as f:
        #这里选择先将类别写入,并使用\t来分割正文内容
        f.write(category + '\t' + content + '\n')
4.这一部分的代码整合
'''
由于一些转码问题,我们先通过二进制方式读取,转换为gbk格式
'''
with open(SOURCE_PATH,'rb') as f:
    source_data = f.read().decode('gbk',"ignore")
    #以<doc>作为切割
    docs = source_data.split('<doc>\n')

    for doc in docs:
        if doc:
            category,contenttitle,content = extract_class_extract(doc)
            #这里先没有使用到contenttitle
            write_category_file(category,content)

对分类好的数据进行划分,划分为train,dev,test三个数据集

1.先查看一下我们各个数据集的数据量
'''
输出所有的数据集的个数
'''
def get_data_size(file_name):
    path = os.path.join(CATEGORY_PATH,file_name)
    with open(path,'rb') as f:
        data = f.read().decode('utf-8',"ignore")
        lines = data.split('\n')
        return len(lines)
'''
输出所有数据集的数据个数
'''
def get_all_data_size():
    #使用字典来保存每一个类型数据集的大小
    category_size = {}
    for file_name in os.listdir(CATEGORY_PATH):
        data_size = get_data_size(file_name)
        #将数据数量添加到字典中
        category_size[file_name] = data_size
    return category_size

category_size = get_all_data_size()
    for key in category_size.keys():
        print(' %s : %d ' % (key,category_size.get(key)))
 2008.txt : 843 
 2010.txt : 559 
 2012.txt : 32 
 astro.txt : 361 
 auto.txt : 138577 
 baobao.txt : 2694 
 book.txt : 6533 
 bschool.txt : 231 
 business.txt : 27490 
 campus.txt : 3 
 chihe.txt : 533 
 club.txt : 350 
 cul.txt : 1925 
 dm.txt : 30 
 expo2010.txt : 2 
 fund.txt : 5016 
 games.txt : 43 
 gd.txt : 1844 
 goabroad.txt : 1107 
 gongyi.txt : 240 
 green.txt : 522 
 health.txt : 23410 
 it.txt : 199872 
 korea.txt : 110 
 learning.txt : 13013 
 media.txt : 670 
 men.txt : 1095 
 money.txt : 10617 
 news.txt : 86053 
 roll.txt : 720958 
 s.txt : 8679 
 sh.txt : 1299 
 sports.txt : 44537 
 stock.txt : 52931 
 travel.txt : 2180 
 tuan.txt : 2 
 tv.txt : 1644 
 v.txt : 9 
 women.txt : 5883 
 yule.txt : 50139 

这里我们看到有一些分类文本中的数量极少,并且还存在着一些意义不明的内容,我们选择舍弃这些极少量的数据文本。

2.筛选数据

在这里我只选择含有500条数据以上的数据集。

category_size = get_all_data_size()
for key in list(category_size.keys()):
    size = category_size.get(key)
    if size < 500:
        del category_size[key]
        print(' %s 数据不足被舍弃 ' % (key))

for key in category_size.keys():
    size = category_size.get(key)
    print(' %s : %d ' % (key, size))

得到了我们留下的文本列表

 2012.txt 数据不足被舍弃 
 astro.txt 数据不足被舍弃 
 bschool.txt 数据不足被舍弃 
 campus.txt 数据不足被舍弃 
 club.txt 数据不足被舍弃 
 dm.txt 数据不足被舍弃 
 expo2010.txt 数据不足被舍弃 
 games.txt 数据不足被舍弃 
 gongyi.txt 数据不足被舍弃 
 korea.txt 数据不足被舍弃 
 tuan.txt 数据不足被舍弃 
 v.txt 数据不足被舍弃 
 2008.txt : 843 
 2010.txt : 559 
 auto.txt : 138577 
 baobao.txt : 2694 
 book.txt : 6533 
 business.txt : 27490 
 chihe.txt : 533 
 cul.txt : 1925 
 fund.txt : 5016 
 gd.txt : 1844 
 goabroad.txt : 1107 
 green.txt : 522 
 health.txt : 23410 
 it.txt : 199872 
 learning.txt : 13013 
 media.txt : 670 
 men.txt : 1095 
 money.txt : 10617 
 news.txt : 86053 
 roll.txt : 720958 
 s.txt : 8679 
 sh.txt : 1299 
 sports.txt : 44537 
 stock.txt : 52931 
 travel.txt : 2180 
 tv.txt : 1644 
 women.txt : 5883 
 yule.txt : 50139 
3.进行划分数据集并将其保存到对应的文件

整体而言,我们的数据集的数量还是比较可观的,所以我按照98:1:1的比例来划分训练集,验证集,和测试集。

由于我没有找到好的函数来帮助将文本txt文件随机划分为三个训练集,所以我采用比较笨的方法,先获取文件长度,使用random.shuffle对其打乱顺序,再根据比例将其中内容写入到三个文件中。

'''
数据集进行切割,并保存
整体而言,我们的数据集的数量还是比较可观的,所以我按照98:1:1的比例来划分训练集,验证集,和测试集
'''
def save_file(category_size):
    SAVE_PATH = 'data/dataset/'
    if not os.path.exists(SAVE_PATH):
        os.mkdir(SAVE_PATH)
    train_file_name = 'train.txt'
    dev_file_name = 'dev.txt'
    test_file_name = 'test.txt'
    train_path = os.path.join(SAVE_PATH,train_file_name)
    dev_path = os.path.join(SAVE_PATH, dev_file_name)
    test_path = os.path.join(SAVE_PATH, test_file_name)
    f_train = open(train_path, 'a', encoding='utf-8')
    f_test = open(test_path, 'a', encoding='utf-8')
    f_dev = open(dev_path , 'a', encoding='utf-8')

    #保证结果一致性
    random.seed(42)
    for category in category_size.keys():
        cat_file = os.path.join(CATEGORY_PATH, category)
        #一个计数器,每次读取新文件时就重置
        count = 0
        with open(cat_file , 'rb') as f:
            # 获取到文件
            data = f.read().decode('utf-8', "ignore")
            lines = data.split('\n')
            size = len(lines)
            list = [x for x in range(size)]
            #使用random.shuffle将文本中的信息打乱
            random.shuffle(list)
            #先将整体数据划分为 98: 2,再从剩下的2份整体划分为1:1的训练集和验证集
            # radio是测试集的比例
            for i in list:
                if count < int(size*0.98):
                    f_train.writelines(lines[i])
                elif count >= int(size*0.98) and count<int(size*0.99):
                    f_dev.writelines(lines[i])
                else:
                    f_test.writelines(lines[i])
                count += 1
    f_train.close()
    f_test.close()
    f_dev.close()
    print("创建完成")
'''
主函数编写
'''
if __name__ == '__main__':
    category_size = get_all_data_size()
    for key in list(category_size.keys()):
        size = category_size.get(key)
        if size < 500:
            del category_size[key]
            print(' %s 数据不足被舍弃 ' % (key))

    sum = 0
    for key in category_size.keys():
        size = category_size.get(key)
        print(' %s : %d ' % (key, size))
    save_file(category_size)

最终得到了三个数据集文件。在查看一下数据的总量是否一致


SAVE_PATH = 'data/dataset/'
train_file_name = 'train.txt'
train_path = os.path.join(SAVE_PATH,train_file_name)
with open(train_path,'rb') as f:
    data = f.read().decode('gbk',"ignore")
    lines = data.split('\n')
    size1 = len(lines)
train_file_name = 'test.txt'
train_path = os.path.join(SAVE_PATH,train_file_name)
with open(train_path,'rb') as f:
    data = f.read().decode('gbk',"ignore")
    lines = data.split('\n')
    size2 = len(lines)
train_file_name = 'dev.txt'
train_path = os.path.join(SAVE_PATH,train_file_name)
with open(train_path,'rb') as f:
    data = f.read().decode('gbk',"ignore")
    lines = data.split('\n')
    size3 = len(lines)

print(size1+size2+size3)
sum = 0
for key in category_size.keys():
    size = category_size.get(key)
    sum += category_size.get(key)
print(sum)

在这里出了一些小问题,第一个检测是1410623,第二个检测是1410626。但好在相差数量极少,具体原因我也不太清楚。这样我们的数据集划分也结束了。下一步,就可以进行模型的训练与测试了。

参考文章:https://www.libinx.com/2018/text-classification-cnn-by-tensorflow/

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值