Text Classification: BiRNN + Attention (TensorFlow 2.0 Implementation)

Other personal links

github
blog

BiRNN+Attention

The full code is on github.
The attention mechanism here follows the paper Feed-Forward Networks with Attention Can Solve Some Long-Term Memory Problems.

The network structure implemented here: Embedding → Bidirectional GRU → Attention → (optional fully connected block) → classifier.

Implemented with the Keras API of TensorFlow 2.0

Custom Attention Layer

This is the approach recommended in TensorFlow 2.0: subclass Layer to define a custom layer.

A few points to note:

  • If other Layer or Sequential structures are needed, assign them as attributes in __init__()
  • Create the weight parameters in build(), and give each one a name
    • If a weight is created without a name, training fails at the second epoch with: AttributeError: 'NoneType' object has no attribute 'replace'
  • Write the computation logic in call()
  • The Attention implemented here uses the GRU output at every step as both key and value, and adds a trainable parameter vector W as the query. Its job is to compute a weight for each step's output; the weighted sum of the outputs is the Attention result (a short shape walkthrough follows this list, before the full implementation)
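As a minimal sketch of that shape flow, using dummy tensors and purely illustrative sizes (not part of the repository code):

import tensorflow as tf

h = tf.random.normal((2, 5, 10))   # GRU outputs: (batch, steps, features)
W = tf.random.normal((10, 1))      # query vector; learned by the layer in practice

e = tf.tanh(tf.matmul(h, W))       # (2, 5, 1) unnormalized score per step
a = tf.nn.softmax(e, axis=1)       # (2, 5, 1) attention weights over the steps
c = tf.reduce_sum(h * a, axis=1)   # (2, 10)  weighted sum of the step outputs
print(c.shape)                     # (2, 10)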
# -*- coding: utf-8 -*-
# @Time : 2020/4/21 13:55
# @Author : zdqzyx
# @File : attention.py
# @Software: PyCharm

from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.keras.layers import Layer
import tensorflow as tf

class Attention(Layer):
    def __init__(self,
                 W_regularizer=None,
                 b_regularizer=None,
                 W_constraint=None,
                 b_constraint=None,
                 bias=True,
                 **kwargs
                 ):
        """
        Keras Layer that implements an Attention mechanism for temporal data.
        Supports Masking.
        Follows the work of Raffel et al. [https://arxiv.org/abs/1512.08756]
        # Input shape
            3D tensor with shape: `(samples, steps, features)`.
        # Output shape
            2D tensor with shape: `(samples, features)`.
        :param kwargs:
        Just put it on top of an RNN Layer (GRU/LSTM/SimpleRNN) with return_sequences=True.
        The dimensions are inferred based on the output shape of the RNN.
        Example:
            # 1
            model.add(LSTM(64, return_sequences=True))
            model.add(Attention())
            # next add a Dense layer (for classification/regression) or whatever...
            # 2
            hidden = LSTM(64, return_sequences=True)(words)
            sentence = Attention()(hidden)
            # next add a Dense layer (for classification/regression) or whatever...
        """
        super(Attention, self).__init__(**kwargs)
        self.bias = bias
        self.init = initializers.get('glorot_uniform')
        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)

    def build(self, input_shape):
        '''
        :param input_shape: (batch, steps, features)
        '''
        self.output_dim = input_shape[-1]
        # scoring vector W: produces one scalar score per step, shape (features, 1)
        self.W = self.add_weight(
                                 name='{}_W'.format(self.name),
                                 shape=(input_shape[2], 1),
                                 initializer=self.init,
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint,
                                 trainable=True
                                 )
        if self.bias:
            # per-step bias, shape (steps, 1)
            self.b = self.add_weight(
                                     name='{}_b'.format(self.name),
                                     shape=(input_shape[1], 1),
                                     initializer='zeros',
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint,
                                     trainable=True
                                     )
        else:
            self.b = None

        self.built = True

    def compute_mask(self, inputs, mask=None):
        # the mask is consumed here: after pooling over the steps there is nothing left to mask
        return None

    def call(self, inputs, mask=None):
        # (N, steps, d) x (d, 1)  ==>  (N, steps, 1)
        e = tf.matmul(inputs, self.W)
        if self.bias:
            e += self.b
        e = tf.tanh(e)
        if mask is not None:
            # give padded steps a large negative score so their attention weight is ~0
            mask = tf.cast(tf.expand_dims(mask, axis=-1), e.dtype)
            e += (1.0 - mask) * -1e9
        a = tf.nn.softmax(e, axis=1)
        # (N, steps, d) * (N, steps, 1)  ==>  (N, steps, d)
        c = inputs * a
        # (N, d)
        c = tf.reduce_sum(c, axis=1)
        return c

    def get_config(self):
        config = super(Attention, self).get_config()
        config.update({'bias': self.bias})
        return config


if __name__=='__main__':
    x = tf.ones((2, 5, 10))
    att = Attention()
    y = att(x)
    print(y.shape)
    print(y)
    print(att.get_config())
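A quick usage sketch with the functional API (the sizes are illustrative, not from the repository): with the mask handling in call(), the layer can sit directly on top of a masked Embedding and a Bidirectional GRU, so padded positions (token id 0) receive ~zero attention weight.

from tensorflow.keras.layers import Input, Embedding, Bidirectional, GRU, Dense
from tensorflow.keras import Model

inputs = Input(shape=(400,))
x = Embedding(input_dim=5000, output_dim=100, mask_zero=True)(inputs)  # mask padding ids (0)
x = Bidirectional(GRU(128, return_sequences=True))(x)                  # (batch, 400, 256)
x = Attention()(x)                                                      # (batch, 256)
outputs = Dense(2, activation='softmax')(x)
model = Model(inputs, outputs)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.summary()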

Custom Model construction

  • One thing worth noting: a Sequential can be defined to wrap a commonly used block, such as the point_wise_feed_forward_network() function below, which wraps n fully connected layers; it is then created in the custom model's __init__() and used in call()
# -*- coding: utf-8 -*-
# @Time : 2020/4/21 13:50
# @Author : zdqzyx
# @File : text_birnn_att.py
# @Software: PyCharm


import tensorflow as tf
from tensorflow.keras.layers import Embedding, Dense, GRU, Bidirectional
from tensorflow.keras import Model
from attention import Attention

def point_wise_feed_forward_network(dense_size):
    # wrap a stack of fully connected ReLU layers in a single Sequential block
    ffn = tf.keras.Sequential()
    for size in dense_size:
        ffn.add(Dense(size, activation='relu'))
    return ffn

class TextBiRNNAtt(Model):

    def __init__(self,
                 maxlen,
                 max_features,
                 embedding_dims,
                 class_num,
                 last_activation='softmax',
                 dense_size=None
                 ):
        '''
        :param maxlen: maximum text length
        :param max_features: vocabulary size
        :param embedding_dims: embedding dimension
        :param class_num: number of classes
        :param last_activation: activation of the output layer
        :param dense_size: list of hidden sizes for the optional feed-forward block, e.g. [128, 64]
        '''
        super(TextBiRNNAtt, self).__init__()
        self.maxlen = maxlen
        self.max_features = max_features
        self.embedding_dims = embedding_dims
        self.class_num = class_num
        self.last_activation = last_activation
        self.dense_size = dense_size

        self.embedding = Embedding(input_dim=self.max_features, output_dim=self.embedding_dims, input_length=self.maxlen)
        self.bi_rnn = Bidirectional(layer=GRU(units=128, activation='tanh', return_sequences=True), merge_mode='concat' ) # LSTM or GRU
        self.attention = Attention()
        if self.dense_size is not None:
            self.ffn = point_wise_feed_forward_network(dense_size)
        self.classifier = Dense(self.class_num, activation=self.last_activation)

    def call(self, inputs, training=None, mask=None):
        if len(inputs.get_shape()) != 2:
            raise ValueError('The rank of inputs of TextBiRNNAtt must be 2, but now is {}'.format(inputs.get_shape()))
        if inputs.get_shape()[1] != self.maxlen:
            raise ValueError('The maxlen of inputs of TextBiRNNAtt must be %d, but now is %d' % (self.maxlen, inputs.get_shape()[1]))

        emb = self.embedding(inputs)
        x = self.bi_rnn(emb)
        x = self.attention(x)
        if self.dense_size is not None:
            x = self.ffn(x)
        output = self.classifier(x)
        return output

    def build_graph(self, input_shape):
        '''Build the subclassed model on a concrete input shape so that summary() works.'''
        input_shape_nobatch = input_shape[1:]
        self.build(input_shape)
        inputs = tf.keras.Input(shape=input_shape_nobatch)
        if not hasattr(self, 'call'):
            raise AttributeError("User should define 'call' method in sub-class model!")
        _ = self.call(inputs)

if __name__=='__main__':
    model = TextBiRNNAtt(maxlen=400,
                        max_features=5000,
                        embedding_dims=100,
                        class_num=2,
                        last_activation='softmax',
                        # dense_size=[128, 64],
                        dense_size = None
                        )
    model.build_graph(input_shape=(None, 400))
    model.summary()
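As a quick end-to-end sketch, the model compiles and trains like any other Keras model; the data below is random and purely illustrative (the repository trains on a real dataset):

import numpy as np

# fake integer-encoded, padded sequences: 1000 samples, 400 tokens each, vocab of 5000
x_train = np.random.randint(0, 5000, size=(1000, 400))
y_train = np.random.randint(0, 2, size=(1000,))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, batch_size=64, epochs=2, validation_split=0.1)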