目前,中文多标签文本分类的方法主要有3种,今天我们来详细介绍及实践其中的一种,算法框架使用的是ALBERT。
一、介绍
- 此项目是在tensorflow版本1.14.0的基础上做的训练和测试。
- 任务类型为中文多标签文本分类,一共有K个标签:
。标签两两之间的关系有的是independent,有的是non independent。
- 模型的输入为一个sentence,输出为一个或者多个label。
- 简单介绍一个例子。
假设个人爱好的集合一共有6个元素:运动、旅游、读书、工作、睡觉、美食。
一般情况下,一个人的爱好有这其中的一个或者多个,那么这就是一个典型的多标签分类任务。
二、框架及算法
1、Placeholder
首先,我们需要设置一些占位符(Placeholder),占位符的作用是在训练和推理的过程中feed模型需要的数据。我们这里需要4个占位符,分别是input_ids、input_masks、segment_ids和label_ids。前面3个是我们了解的BERT输入特征,最后面一个是标签的id。
2、ALBERT token-vectors
从图中红色的框内可以看出,ALBERT需要传入3个参数(input_ids、input_masks、segment_ids),就可以得到我们所需要的一个2维向量output_layer:(batch_size, hidden_size)。
有人在这里就会好奇,为什么ALBERT输出的是一个2维向量,而不是一个3维向量(batch_size, sequence_length, hidden_size)呢?那我们来看一下源码,弄清楚self.model.get_pooled_output()的来历。
其中self.sequence_ouput其实就是我们所说的那个3维向量(batch_size, sequence_length, hidden_size)。我们对这个3维向量做了一个"pooler"的操作,从而使之变成了一个2维的向量,这个操作是上面蓝色方框内的内容。
蓝色方框内的解释为:”We "pool" the model by simply taking the hidden state corresponding to the first token. We assume that this has been pre-trained“。这句话怎么理解呢?意思是将整个句子的特征信息投射到句子第一个字的隐藏状态向量上面。并且,认为这个它是通过预训练得到的。
3、Full connection
最后,就是一个全连接层了。很简单,全连接层的作用是将output_layer投射到我们的标签上面。
4、上面3点在多标签文本分类和文本分类并没有区别。那么区别在哪里呢?
主要有以下3个区别:
- 交叉熵
- 输出概率
- 输出标签
4.1、交叉熵
在文本分类中,我们使用的交叉熵为tf.nn.softmax_cross_entropy_with_logits;在多标签文本分类中,我们使用的交叉熵则为tf.nn.sigmoid_cross_entropy_with_logits。这样做的原因:
- tf.nn.sigmoid_cross_entropy_with_logits测量离散分类任务中的概率误差,其中每个类是独立的而不是互斥的。这适用于多标签分类问题。
- tf.nn.softmax_cross_entropy_with_logits测量离散分类任务中的概率误差,其中类之间是互斥的(每个条目恰好在一个类中)。这适用多分类问题。
4.2、输出概率
在文本分类中,输出概率为tf.nn.softmax(logits, axis=-1);在多标签文本分类中,输出概率为tf.nn.sigmoid(logits)。这样做的原因:
- 在简单的二进制分类中,sigmoid和softmax没有太大的区别。
- 在多分类的情况下,sigmoid允许处理非独占标签(也称为多标签),而softmax处理独占类。
4.3、输出标签
在文本分类(多元文本分类)中,label_ids的维度为(batch_size);在多标签文本分类中,它的维度为(batch_size,num_labels)。这样做的原因:
- 在多元文本分类中,最后得到的标签只有一个,并且必须是其中的一个。
- 在多标签文本分类中,最后得到的标签可能有1个或者多个。
一般的多元分类是通过tf.argmax(logits)实现,返回的是最大的那个数值所在的label_id,因为logits对应每一个label_id都有一个概率。但是,在多标签分类中,我们需要得到的是每一个标签是否可以作为输出标签,所以每一个标签可以作为输出标签的概率都会量化为一个0到1之间的值。所以当某一个标签对应输出概率小于0.5时,我们认为它不能作为当前句子的输出标签;反之,如果大于等于0.5,那么它代表了当前句子的输出标签之一。
三、实践及框架图
1、框架图
2、模型Loss和Accuracy变化曲线图
我们可以发现,这里的Loss和Accuracy的变化趋势和多元文本分类有较大的区别。在多标签文本分类的训练过程中,Loss的下降幅度非常快,但并不代表模型的收敛快。在多元文本分类的训练过程中,Loss一般在0.1-0.2之间的时候,模型基本上已经收敛。但是,在多标签文本分类(当前框架下)的过程中,当Loss到达0.1-0.2时,模型收敛还需较多的steps。根据训练经验,在多标签文本分类(这个框架下)的情况下,Loss往往要达到0.0001-0.001之间,模型才收敛。
四、代码链接
hellonlp/classifier_multi_labelgithub.com其它相关文章:HelloNLP:多标签文本分类介绍,以及对比训练zhuanlan.zhihu.com