Hand-writing LayerNormalization, attention, self_attention and multi_head_attention in Keras

import numpy as np
import keras
from keras import backend as K
from keras.models import Model, load_model
from keras.layers import Input, Add, Dense, Dropout, Lambda, Permute, Conv1D, Concatenate


class LayerNormalization(keras.layers.Layer):
    
    def __init__(self,
                 center = True,
                 scale = True,
                 epsilon = None,
                 gamma_initializer = "ones",
                 beta_initializer = "zeros",
                 gamma_regularizer = None,
                 beta_regularizer = None,
                 gamma_constraint = None,
                 beta_constraint = None,
                 **kwargs):
        """
        Layer normalization layer
            reference: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf)
                :param center: Add an offset parameter if it is True.
                :param scale: Add a scale parameter if it is True.
                :param epsilon: Epsilon for calculating variance.
                :param gamma_initializer: Initializer for the gamma weight.
                :param beta_initializer: Initializer for the beta weight.
                :param gamma_regularizer: Optional regularizer for the gamma weight.
                :param beta_regularizer: Optional regularizer for the beta weight.
                :param gamma_constraint: Optional constraint for the gamma weight.
                :param beta_constraint: Optional constraint for the beta weight.
                :param kwargs:
        """        
        super(LayerNormalization,self).__init__(**kwargs)
        self.supports_masking = True
        self.center = center
        self.scale = scale
        if epsilon is None:
            epsilon = K.epsilon() * K.epsilon()
        self.epsilon = epsilon
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.beta_initializer = keras.initializers.get(beta_initializer)
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        self.gamma_constraint = keras.constraints.get(gamma_constraint)
        self.beta_constraint = keras.constraints.get(beta_constraint)
        self.gamma, self.beta = None, None

    def get_config(self):
        config = {
            'center': self.center,
            'scale': self.scale,
            'epsilon': self.epsilon,
            'gamma_initializer': keras.initializers.serialize(self.gamma_initializer),
            'beta_initializer': keras.initializers.serialize(self.beta_initializer),
            'gamma_regularizer': keras.regularizers.serialize(self.gamma_regularizer),
            'beta_regularizer': keras.regularizers.serialize(self.beta_regularizer),
            'gamma_constraint': keras.constraints.serialize(self.gamma_constraint),
            'beta_constraint': keras.constraints.serialize(self.beta_constraint),
        }
        base_config = super(LayerNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

    def compute_mask(self, inputs, input_mask=None):
        return input_mask

    def build(self, input_shape):
        shape = input_shape[-1:]
        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                name='gamma',
            )
        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                name='beta',
            )
        super(LayerNormalization, self).build(input_shape)

    def call(self, inputs, training=None):
        # Normalize over the last (feature) axis, then apply the learned scale and offset.
        mean = K.mean(inputs, axis=-1, keepdims=True)
        variance = K.mean(K.square(inputs - mean), axis=-1, keepdims=True)
        std = K.sqrt(variance + self.epsilon)
        outputs = (inputs - mean) / std
        if self.scale:
            outputs *= self.gamma
        if self.center:
            outputs += self.beta
        return outputs        
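
A minimal usage sketch (not from the original post; the input shape is an arbitrary choice): the layer normalizes each position over its feature axis and keeps the input shape unchanged.

# Usage sketch for the custom LayerNormalization layer -- shapes are illustrative only.
inp = Input(shape=(10, 64))                      # (batch, seq_len=10, features=64)
out = LayerNormalization()(inp)                  # same shape, normalized over the last axis
demo_model = Model(inputs=inp, outputs=out)
demo_model.summary()
# Because get_config() is implemented, a saved model can be restored with
# load_model("model.h5", custom_objects={"LayerNormalization": LayerNormalization}).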



def attention(x, n_factor, dropout):
    """
    attention layer @ Morxrc
        attention - plays a role similar to an RNN: it collects global information, while being easy to parallelize.
        Conv1D    - 1x1 convolutions produce Q, K, V (Query, Key, Value). Q and K are used to measure similarity,
                    while V carries each position's own information. In much current NLP work the key and value
                    come from the same tensor (key = value). n_factor is the number of filters, i.e. the output
                    dimension; 1 is the kernel_size, the length of the convolution window.
        Permute   - permutes the (non-batch) dimensions; (2, 1) swaps the first and second axes, i.e. a transpose.
        axis=-1   - the softmax operates over the last dimension.
    Note: the dropout argument is not used inside this function; dropout is applied in the wrappers below.
    """
    # 1x1 convolutions that project x into Query, Key and Value spaces (each with n_factor channels).
    x_Q = Conv1D(n_factor, 1, activation="linear",
                 kernel_initializer="glorot_uniform",
                 bias_initializer="glorot_uniform",
                 )(x)
    x_K = Conv1D(n_factor, 1, activation="linear",
                 kernel_initializer="glorot_uniform",
                 bias_initializer="glorot_uniform",
                 )(x)
    x_V = Conv1D(n_factor, 1, activation="linear",
                 kernel_initializer="glorot_uniform",
                 bias_initializer="glorot_uniform",
                 )(x)

    x_KT = Permute((2, 1))(x_K)                                                        # transpose K to (batch, n_factor, seq_len)
    res = Lambda(lambda c: K.batch_dot(c[0], c[1]) / np.sqrt(n_factor))([x_Q, x_KT])   # scaled dot product Q·K^T / sqrt(n_factor)
    att = Lambda(lambda c: K.softmax(c, axis=-1))(res)                                 # attention weights
    att = Lambda(lambda c: K.batch_dot(c[0], c[1]))([att, x_V])                        # weighted sum of the values
    return att
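
A quick shape check (my own illustration; seq_len=10, 64 input features and n_factor=32 are arbitrary): for an input of shape (batch, seq_len, features), attention() returns (batch, seq_len, n_factor).

# Shape-check sketch for attention() -- illustrative only.
inp = Input(shape=(10, 64))
att_out = attention(inp, n_factor=32, dropout=0.1)
attn_model = Model(inputs=inp, outputs=att_out)
print(attn_model.output_shape)                   # expected: (None, 10, 32)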

def self_attention(x, n_factor, dropout):
    # Single-head self-attention block with a residual connection.
    # Note: the residual Add requires the last dimension of x to equal n_factor.
    att = attention(x, n_factor, dropout)
    att = LayerNormalization()(att)
    if dropout > 0:
        att = Dropout(dropout)(att)
    x = Add()([x, att])
    return x

def multi_head_attention(x, n_factor, n_head, dropout):
    # n_factor should be divisible by n_head; each head works in a subspace of size n_factor_head.
    n_factor_head = n_factor // n_head
    # Run n_head independent attention blocks and concatenate their outputs along the feature axis.
    heads = [attention(x, n_factor_head, dropout) for i in range(n_head)]
    att = Concatenate()(heads)
    att = Dense(n_factor,
                kernel_initializer="glorot_uniform",
                bias_initializer="glorot_uniform",
                )(att)
    # Add & Norm: residual connection (x must have n_factor channels), then layer normalization and dropout.
    x = Add()([x, att])
    x = LayerNormalization()(x)
    if dropout > 0:
        x = Dropout(dropout)(x)
    return x
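
A minimal end-to-end sketch (my own example; all shapes and hyperparameters are arbitrary): project the input to n_factor channels first so that the residual Adds inside the blocks line up, then stack the blocks defined above.

# End-to-end usage sketch -- illustrative shapes and hyperparameters.
n_factor = 32
inp = Input(shape=(10, 64))                                    # (batch, seq_len=10, features=64)
h = Conv1D(n_factor, 1, activation="linear")(inp)              # project to n_factor channels
h = self_attention(h, n_factor, dropout=0.1)                   # (batch, 10, 32)
h = multi_head_attention(h, n_factor, n_head=4, dropout=0.1)   # (batch, 10, 32)
model = Model(inputs=inp, outputs=h)
model.summary()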
