使用BERT对搜狐新闻文本进行分类(一)
本人小白,刚刚接触NLP领域,通过写博客方式来记录写一下学习的过程。如果有文章中有不对或可以改善指出,希望大家可以帮我指出。谢谢大家。
处理搜狐新闻数据文本
数据集来源:http://www.sogou.com/labs/resource/cs.php,下载的是完整版648M。
1.这里直接使用解压后的文件“news_sohusite_xml.dat”(推荐先使用小的文本进行试验处理)
'''
由于一些转码问题,我们先通过二进制方式读取,转换为gbk格式
'''
SOURCE_PATH = 'data/news_sohusite_xml.smarty.dat'
with open(SOURCE_PATH,'rb') as f:
source_data = f.read().decode('gbk',"ignore")
2.接下来需要对格式进行转换。
可以看到一般来说新闻的类别隐藏在url域中,如gongyi对应公益。对应新闻的题目,content对应新闻的正文。
这里使用python的re模块来对内容进行提取,在这里我将类别,标题,正文都提取取出来,在后面的训练时可以选择是否对标题内容进行训练。
数据格式为
<doc>
<url>页面URL</url>
<docno>页面ID</docno>
<contenttitle>页面标题</contenttitle>
<content>页面内容</content>
</doc>
#提取新闻的内容
def extract_class_extract(doc):
#先对<url>进行切割 分割为 <url>与其后面的内容,选取其后面的内容,再对</url>进行切割,选取</url>前面的内容
url = doc.split('<url>')[1].split('</url>')[0]
# contexttitle同理
contenttitle = doc.split('<contenttitle>')[1].split('</contenttitle>')[0]
#context同理
content = doc.split('<content>')[1].split('</content>')[0]
#之后从url中提取类别,使用正则表达式,这里得到的是一个列表但其实只有一项,我们直接提取出来
category = re.findall(r"http://(.*?).sohu.com/", url)
category = category[0]
#这样提取的信息可能存在子领域(如try.women,apple.it等,但我们目前是为了分类大领域(只要women或者it),所以需要进一步提取
#以'.'分割后的最后一个为大领域
category = category.split('.')[-1]
return category,contenttitle,content
3.将提取的新闻内容分别放入不同的文件中保存
def write_category_file(category,content):
dir_name = 'data/category/'
if not os.path.exists(dir_name):
os.mkdir(dir_name)
file_name = category+'.txt'
path = os.path.join(dir_name, file_name)
with open(path, 'a', encoding='gbk') as f:
#这里选择先将类别写入,并使用\t来分割正文内容
f.write(category + '\t' + content + '\n')
4.这一部分的代码整合
'''
由于一些转码问题,我们先通过二进制方式读取,转换为gbk格式
'''
with open(SOURCE_PATH,'rb') as f:
source_data = f.read().decode('gbk',"ignore")
#以<doc>作为切割
docs = source_data.split('<doc>\n')
for doc in docs:
if doc:
category,contenttitle,content = extract_class_extract(doc)
#这里先没有使用到contenttitle
write_category_file(category,content)
对分类好的数据进行划分,划分为train,dev,test三个数据集
1.先查看一下我们各个数据集的数据量
'''
输出所有的数据集的个数
'''
def get_data_size(file_name):
path = os.path.join(CATEGORY_PATH,file_name)
with open(path,'rb') as f:
data = f.read().decode('utf-8',"ignore")
lines = data.split('\n')
return len(lines)
'''
输出所有数据集的数据个数
'''
def get_all_data_size():
#使用字典来保存每一个类型数据集的大小
category_size = {}
for file_name in os.listdir(CATEGORY_PATH):
data_size = get_data_size(file_name)
#将数据数量添加到字典中
category_size[file_name] = data_size
return category_size
category_size = get_all_data_size()
for key in category_size.keys():
print(' %s : %d ' % (key,category_size.get(key)))
2008.txt : 843
2010.txt : 559
2012.txt : 32
astro.txt : 361
auto.txt : 138577
baobao.txt : 2694
book.txt : 6533
bschool.txt : 231
business.txt : 27490
campus.txt : 3
chihe.txt : 533
club.txt : 350
cul.txt : 1925
dm.txt : 30
expo2010.txt : 2
fund.txt : 5016
games.txt : 43
gd.txt : 1844
goabroad.txt : 1107
gongyi.txt : 240
green.txt : 522
health.txt : 23410
it.txt : 199872
korea.txt : 110
learning.txt : 13013
media.txt : 670
men.txt : 1095
money.txt : 10617
news.txt : 86053
roll.txt : 720958
s.txt : 8679
sh.txt : 1299
sports.txt : 44537
stock.txt : 52931
travel.txt : 2180
tuan.txt : 2
tv.txt : 1644
v.txt : 9
women.txt : 5883
yule.txt : 50139
这里我们看到有一些分类文本中的数量极少,并且还存在着一些意义不明的内容,我们选择舍弃这些极少量的数据文本。
2.筛选数据
在这里我只选择含有500条数据以上的数据集。
category_size = get_all_data_size()
for key in list(category_size.keys()):
size = category_size.get(key)
if size < 500:
del category_size[key]
print(' %s 数据不足被舍弃 ' % (key))
for key in category_size.keys():
size = category_size.get(key)
print(' %s : %d ' % (key, size))
得到了我们留下的文本列表
2012.txt 数据不足被舍弃
astro.txt 数据不足被舍弃
bschool.txt 数据不足被舍弃
campus.txt 数据不足被舍弃
club.txt 数据不足被舍弃
dm.txt 数据不足被舍弃
expo2010.txt 数据不足被舍弃
games.txt 数据不足被舍弃
gongyi.txt 数据不足被舍弃
korea.txt 数据不足被舍弃
tuan.txt 数据不足被舍弃
v.txt 数据不足被舍弃
2008.txt : 843
2010.txt : 559
auto.txt : 138577
baobao.txt : 2694
book.txt : 6533
business.txt : 27490
chihe.txt : 533
cul.txt : 1925
fund.txt : 5016
gd.txt : 1844
goabroad.txt : 1107
green.txt : 522
health.txt : 23410
it.txt : 199872
learning.txt : 13013
media.txt : 670
men.txt : 1095
money.txt : 10617
news.txt : 86053
roll.txt : 720958
s.txt : 8679
sh.txt : 1299
sports.txt : 44537
stock.txt : 52931
travel.txt : 2180
tv.txt : 1644
women.txt : 5883
yule.txt : 50139
3.进行划分数据集并将其保存到对应的文件
整体而言,我们的数据集的数量还是比较可观的,所以我按照98:1:1的比例来划分训练集,验证集,和测试集。
由于我没有找到好的函数来帮助将文本txt文件随机划分为三个训练集,所以我采用比较笨的方法,先获取文件长度,使用random.shuffle对其打乱顺序,再根据比例将其中内容写入到三个文件中。
'''
数据集进行切割,并保存
整体而言,我们的数据集的数量还是比较可观的,所以我按照98:1:1的比例来划分训练集,验证集,和测试集
'''
def save_file(category_size):
SAVE_PATH = 'data/dataset/'
if not os.path.exists(SAVE_PATH):
os.mkdir(SAVE_PATH)
train_file_name = 'train.txt'
dev_file_name = 'dev.txt'
test_file_name = 'test.txt'
train_path = os.path.join(SAVE_PATH,train_file_name)
dev_path = os.path.join(SAVE_PATH, dev_file_name)
test_path = os.path.join(SAVE_PATH, test_file_name)
f_train = open(train_path, 'a', encoding='utf-8')
f_test = open(test_path, 'a', encoding='utf-8')
f_dev = open(dev_path , 'a', encoding='utf-8')
#保证结果一致性
random.seed(42)
for category in category_size.keys():
cat_file = os.path.join(CATEGORY_PATH, category)
#一个计数器,每次读取新文件时就重置
count = 0
with open(cat_file , 'rb') as f:
# 获取到文件
data = f.read().decode('utf-8', "ignore")
lines = data.split('\n')
size = len(lines)
list = [x for x in range(size)]
#使用random.shuffle将文本中的信息打乱
random.shuffle(list)
#先将整体数据划分为 98: 2,再从剩下的2份整体划分为1:1的训练集和验证集
# radio是测试集的比例
for i in list:
if count < int(size*0.98):
f_train.writelines(lines[i])
elif count >= int(size*0.98) and count<int(size*0.99):
f_dev.writelines(lines[i])
else:
f_test.writelines(lines[i])
count += 1
f_train.close()
f_test.close()
f_dev.close()
print("创建完成")
'''
主函数编写
'''
if __name__ == '__main__':
category_size = get_all_data_size()
for key in list(category_size.keys()):
size = category_size.get(key)
if size < 500:
del category_size[key]
print(' %s 数据不足被舍弃 ' % (key))
sum = 0
for key in category_size.keys():
size = category_size.get(key)
print(' %s : %d ' % (key, size))
save_file(category_size)
最终得到了三个数据集文件。在查看一下数据的总量是否一致
SAVE_PATH = 'data/dataset/'
train_file_name = 'train.txt'
train_path = os.path.join(SAVE_PATH,train_file_name)
with open(train_path,'rb') as f:
data = f.read().decode('gbk',"ignore")
lines = data.split('\n')
size1 = len(lines)
train_file_name = 'test.txt'
train_path = os.path.join(SAVE_PATH,train_file_name)
with open(train_path,'rb') as f:
data = f.read().decode('gbk',"ignore")
lines = data.split('\n')
size2 = len(lines)
train_file_name = 'dev.txt'
train_path = os.path.join(SAVE_PATH,train_file_name)
with open(train_path,'rb') as f:
data = f.read().decode('gbk',"ignore")
lines = data.split('\n')
size3 = len(lines)
print(size1+size2+size3)
sum = 0
for key in category_size.keys():
size = category_size.get(key)
sum += category_size.get(key)
print(sum)
在这里出了一些小问题,第一个检测是1410623,第二个检测是1410626。但好在相差数量极少,具体原因我也不太清楚。这样我们的数据集划分也结束了。下一步,就可以进行模型的训练与测试了。
参考文章:https://www.libinx.com/2018/text-classification-cnn-by-tensorflow/