NLP修行之路-day4 用自定义的FastText做新闻分类

鉴于天池实验室不是很方便安装python工具包,这次用FastText来做文本分类任务的话,就用自定义一个简单的模型,下面开始附上代码。

# 绘图案例 an example of matplotlib
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from keras.preprocessing.sequence import pad_sequences
import keras.layers as layers
from keras.layers import Embedding
from keras.layers import Dense
from keras.models import Sequential
from keras.layers import GlobalAveragePooling1D

#max_f:1000 我这边只截取了前1000个字符
def proprocessdata(filepath='datalab/72510/train_set.csv',max_f=1000):
    #nrows:50000 是因为我这边在天池notebook里面写的,内存有限,所以就取了50000条
    df=pd.read_csv(filepath,sep='\t',nrows=50000)
    x_temp=df['text'].apply(lambda x:x.split(' '))
    x=pad_sequences(x_temp,maxlen=max_f,dtype='int32',padding='post', truncating='pre', value=0.)
    y=df['label'].values
    #test_size:0.25 从数据集中取四分之一作为测试集
    train_x,test_x,train_y,test_y=train_test_split(x,y,test_size=0.25)
    return train_x,test_x,train_y,test_y

接下来自定义FastText的模型

def FastText():
    model=Sequential()
    model.add(Embedding(input_dim=8000,output_dim=100,input_length=1000))
    model.add(GlobalAveragePooling1D())
    model.add(Dense(units=14,activation='softmax'))
    #使用sparse_categorical_crossentropy 方法,省去onehot独热编码的步骤
    model.compile(optimizer='Adam',metrics=['accuracy'],loss='sparse_categorical_crossentropy')
    return model

加载数据,开始训练模型
我这边做了30个轮训,对于37500条数据来说可能会过拟合,但是鉴于训练模型一般往过拟合的方向来做的原则来说,也算可以接受。

train_x,test_x,train_y,test_y=proprocessdata()
model=FastText()
history=model.fit(train_x,train_y,validation_split=0.2,epochs=30,batch_size=64)

训练了3万多条数据,准确率在验证集上大概在0.90左右,下面开始绘制损失值趋势图。

plt.plot(np.linspace(0,29,30),history.history['loss'],label='loss')
plt.plot(np.linspace(0,29,30),history.history['val_loss'],label='vloss')
plt.legend()
plt.show()

在这里插入图片描述

我们再看看准确率的趋势图。

plt.plot(np.linspace(0,29,30),history.history['acc'],label='acc')
plt.plot(np.linspace(0,29,30),history.history['val_acc'],label='vacc')
plt.legend()
plt.show()

在这里插入图片描述
接下来我在测试集上评估一下准确率再验证一下。

model.evaluate(test_x,test_y,batch_size=64)

12500/12500 [==============================] - 3s 241us/step
[0.38796657196521761, 0.90167999998092652]
目前看来37500条训练数据,文本只截取1000个字符的话,准确率大概也能达到0.9左右。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值