MultiHeadAttention.py
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
"""
class MultiHeadAttention(keras.Model):
# https://machinetalk.org/2019/04/29/create-the-transformer-with-tensorflow-2-0/
def __init__(self, model_size, h, dropout):
super(MultiHeadAttention, self).__init__()
self.query_size = model_size // h
self.key_size = model_size // h
self.value_size = model_size // h
self.h = h
self.wq = [layers.Dense(self.query_size) for _ in range(h)]
self.wk = [layers.Dense(self.key_size) for _ in range(h)]
self.wv = [layers.Dense(self.value_size) for _ in range(h)]
self.wo = layers.Dense(model_size)
self.dropout = layers.Dropout(dropout)
def call(self, query, value):
# query has shape (batch, query_len, model_size)
# value has shape (batch, value_len, model_size)
heads = []
for i in range(self.h):
score = self.dropout(tf.matmul(self.wq[i](query), self.wk[i](value), transpose_b=True))
# Here we scale the score as described in the paper
score /= tf.math.sqrt(tf.dtypes.cast(self.key_size, tf.float32))
# score has shape (batch, query_len, value_len)
alignment = tf.nn.softmax(score, axis=2)
# alignment has shape (batch, query_len, value_len)
head = tf.matmul(alignment, self.wv[i](value))
# head has shape (batch, decoder_len, value_size)
heads.append(head)
# Concatenate all the attention heads
# so that the last dimension summed up to model_size
heads = tf.concat(heads, axis=2)
heads = self.wo(heads)
# heads has shape (batch, query_len, model_size)
return heads
"""
class MultiHeadAttention(keras.Model):
    # https://www.tensorflow.org/tutorials/text/transformer
    def __init__(self, d_model, num_heads, dropout):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        self.depth = d_model // self.num_heads
        self.wq = layers.Dense(d_model)
        self.wk = layers.Dense(d_model)
        self.wv = layers.Dense(d_model)
        self.dropout = layers.Dropout(dropout)
        self.dense = layers.Dense(d_model)
    def scaled_dot_product_attention(self, q, k, v, mask):
        """Compute the attention weights.
        q, k, v must have matching leading dimensions.
        k, v must have a matching second-to-last dimension, i.e. seq_len_k = seq_len_v.
        The mask has different shapes depending on its type (padding or look-ahead),
        but it must be broadcastable for the addition.
        Args:
            q: query shape == (..., seq_len_q, depth)
            k: key shape == (..., seq_len_k, depth)
            v: value shape == (..., seq_len_v, depth_v)
            mask: float tensor with a shape broadcastable to
                  (..., seq_len_q, seq_len_k). Defaults to None.
        Returns:
            output, attention_weights
        """
        matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
        # Scale matmul_qk
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
        # Add the mask to the scaled tensor.
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)
        # softmax is normalized on the last axis (seq_len_k) so that the scores
        # add up to 1.
        attention_weights = self.dropout(tf.nn.softmax(scaled_attention_logits, axis=-1))  # (..., seq_len_q, seq_len_k)
        output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
        return output, attention_weights
    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result so the shape is (batch_size, num_heads, seq_len, depth).
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    def call(self, q, k, v, mask):
        batch_size = tf.shape(q)[0]
        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)
        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)
        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = self.scaled_dot_product_attention(
            q, k, v, mask)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)
        concat_attention = tf.reshape(scaled_attention,
                                      (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)
        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)
        return output, attention_weights
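
# A minimal usage sketch (not part of the original file; the shapes below are assumed):
# run one batch through the layer as self-attention and check the output and
# attention-weight shapes, mirroring the calling convention used in rztx.py.
"""
mha = MultiHeadAttention(d_model=512, num_heads=8, dropout=0.1)
x = tf.random.normal([32, 10, 512])   # (batch_size, seq_len, d_model)
out, attn = mha(x, x, x, None)        # query = key = value -> self-attention, no mask
print(out.shape)                      # (32, 10, 512)
print(attn.shape)                     # (32, 8, 10, 10)
"""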
rztx.py
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import activations
import tensorflow_addons as tfa
from MultiHeadAttention import MultiHeadAttention
class RZTXEncoderLayer(keras.Model):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='relu'):
        super(RZTXEncoderLayer, self).__init__()
        # d_model = E; the PyTorch reference uses Q:[L,N,E], K:[S,N,E], V:[S,N,E] with batch size N
        # (tensors here are batch-first instead).
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout=dropout)  # self-attention; stand-in until TensorFlow ships a built-in multi-head layer
        # Implementation of the feedforward model
        self.linear1 = layers.Dense(dim_feedforward)  # first linear layer
        self.dropout = layers.Dropout(dropout)
        self.linear2 = layers.Dense(d_model)  # second linear layer
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
        self.resweight = tf.Variable(0.0, trainable=True)  # learned ReZero residual weight (alpha)
        if activation == "relu":
            self.activation = activations.relu
        elif activation == "gelu":
            self.activation = tfa.activations.gelu

    def __setstate__(self, state):
        # Carried over from the PyTorch reference implementation for unpickling compatibility.
        if 'activation' not in state:
            state['activation'] = activations.relu
        super().__setstate__(state)
    def call(self, src, mask=None):
        # Self-attention sublayer
        src2 = src
        src2, _ = self.self_attn(src2, src2, src2, mask)  # [bs, l, emb]
        src2 = src2 * self.resweight
        src = src + self.dropout1(src2)  # [bs, l, emb]
        # Pointwise feed-forward sublayer
        src2 = src
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src2 = src2 * self.resweight
        src = src + self.dropout2(src2)
        return src
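
# Quick sanity check (illustrative, not in the original file): resweight is initialized to 0,
# so both residual branches contribute nothing at first and a freshly built encoder layer is
# the identity map. This is the ReZero trick the layer is built around.
"""
layer = RZTXEncoderLayer(d_model=64, nhead=8, dim_feedforward=128)
x = tf.random.normal([2, 5, 64])
y = layer(x)
print(float(tf.reduce_max(tf.abs(y - x))))  # ~0.0 before any training
"""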
class RZTXDecoderLayer(keras.Model):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(RZTXDecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiHeadAttention(d_model, nhead, dropout=dropout)
        # Implementation of the feedforward model
        self.linear1 = layers.Dense(dim_feedforward)
        self.dropout = layers.Dropout(dropout)
        self.linear2 = layers.Dense(d_model)
        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
        self.dropout3 = layers.Dropout(dropout)
        self.resweight = tf.Variable(0.0, trainable=True)  # learned ReZero residual weight (alpha)
        if activation == "relu":
            self.activation = activations.relu
        elif activation == "gelu":
            self.activation = tfa.activations.gelu
    def call(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # Masked self-attention over the target sequence
        tgt2, _ = self.self_attn(tgt, tgt, tgt, tgt_mask)
        tgt = tgt + self.dropout1(tgt2) * self.resweight
        # Encoder-decoder attention: Q = tgt; K = memory; V = memory
        tgt2, _ = self.multihead_attn(tgt, memory, memory, memory_mask)
        tgt = tgt + self.dropout2(tgt2) * self.resweight
        # Pointwise feed-forward sublayer
        if hasattr(self, "activation"):
            tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        else:  # for backward compatibility
            tgt2 = self.linear2(self.dropout(activations.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2) * self.resweight
        return tgt
"""
encoder_layer = RZTXEncoderLayer(d_model=512, nhead=8)
src = tf.random.normal([32, 10, 512]) # [bs,q,emb]
out = encoder_layer(src)
print(out.shape)
decoder_layer = RZTXDecoderLayer(d_model=512, nhead=8)
memory = tf.random.normal([32, 10, 512])
tgt = tf.random.normal([32, 20, 512])
out = decoder_layer(tgt, memory)
print(out.shape)
"""