鉴于天池实验室不是很方便安装python工具包,这次用FastText来做文本分类任务的话,就用自定义一个简单的模型,下面开始附上代码。
# 绘图案例 an example of matplotlib
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from keras.preprocessing.sequence import pad_sequences
import keras.layers as layers
from keras.layers import Embedding
from keras.layers import Dense
from keras.models import Sequential
from keras.layers import GlobalAveragePooling1D
#max_f:1000 我这边只截取了前1000个字符
def proprocessdata(filepath='datalab/72510/train_set.csv',max_f=1000):
#nrows:50000 是因为我这边在天池notebook里面写的,内存有限,所以就取了50000条
df=pd.read_csv(filepath,sep='\t',nrows=50000)
x_temp=df['text'].apply(lambda x:x.split(' '))
x=pad_sequences(x_temp,maxlen=max_f,dtype='int32',padding='post', truncating='pre', value=0.)
y=df['label'].values
#test_size:0.25 从数据集中取四分之一作为测试集
train_x,test_x,train_y,test_y=train_test_split(x,y,test_size=0.25)
return train_x,test_x,train_y,test_y
接下来自定义FastText的模型
def FastText():
model=Sequential()
model.add(Embedding(input_dim=8000,output_dim=100,input_length=1000))
model.add(GlobalAveragePooling1D())
model.add(Dense(units=14,activation='softmax'))
#使用sparse_categorical_crossentropy 方法,省去onehot独热编码的步骤
model.compile(optimizer='Adam',metrics=['accuracy'],loss='sparse_categorical_crossentropy')
return model
加载数据,开始训练模型
我这边做了30个轮训,对于37500条数据来说可能会过拟合,但是鉴于训练模型一般往过拟合的方向来做的原则来说,也算可以接受。
train_x,test_x,train_y,test_y=proprocessdata()
model=FastText()
history=model.fit(train_x,train_y,validation_split=0.2,epochs=30,batch_size=64)
训练了3万多条数据,准确率在验证集上大概在0.90左右,下面开始绘制损失值趋势图。
plt.plot(np.linspace(0,29,30),history.history['loss'],label='loss')
plt.plot(np.linspace(0,29,30),history.history['val_loss'],label='vloss')
plt.legend()
plt.show()
我们再看看准确率的趋势图。
plt.plot(np.linspace(0,29,30),history.history['acc'],label='acc')
plt.plot(np.linspace(0,29,30),history.history['val_acc'],label='vacc')
plt.legend()
plt.show()
接下来我在测试集上评估一下准确率再验证一下。
model.evaluate(test_x,test_y,batch_size=64)
12500/12500 [==============================] - 3s 241us/step
[0.38796657196521761, 0.90167999998092652]
目前看来37500条训练数据,文本只截取1000个字符的话,准确率大概也能达到0.9左右。