文章目录
BERT模型简介
BERT 主要利用 Transformer 的 Encoder 部分,结合 Masked Language Model(掩码语言模型)预训练任务,训练双向注意力模型,并将其应用于语言建模。
BERT模型拆解
完整项目参考:https://github.com/huanghao128/bert_example
tensorflow模块导入
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import activations
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras import backend as K
multi-head attention
class MultiHeadAttention(keras.Model):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each projected with their own
    dense layer, split into ``num_heads`` heads of size
    ``hidden_size // num_heads``, attended over independently, merged
    back together and passed through a final dense projection.

    Args:
        hidden_size: Width of the projections and of the output.
            Must be a multiple of ``num_heads``.
        num_heads: Number of attention heads.
        **kwargs: Forwarded to ``keras.Model.__init__``.

    Raises:
        ValueError: If ``hidden_size`` is not a multiple of ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        # Without this check the integer division below silently truncates
        # and the reshapes in _split_heads()/call() either fail deep inside
        # the graph or fold sequence positions into the head dimension.
        if hidden_size % num_heads != 0:
            raise ValueError(
                "hidden_size (%d) must be a multiple of num_heads (%d)"
                % (hidden_size, num_heads)
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.WQ = layers.Dense(hidden_size, name="dense_query")
        self.WK = layers.Dense(hidden_size, name="dense_key")
        self.WV = layers.Dense(hidden_size, name="dense_value")
        self.dense = layers.Dense(hidden_size)

    def _split_heads(self, x, batch_size):
        """Reshape ``x`` to ``(batch, num_heads, seq_len, head_size)``."""
        x = tf.reshape(x, shape=[batch_size, -1, self.num_heads, self.head_size])
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: Tensor of shape ``(batch, maxlen, hidden_size)``.
            key: Tensor of shape ``(batch, maxlen, hidden_size)``.
            value: Tensor of shape ``(batch, maxlen, hidden_size)``.
            mask: Optional rank-2 tensor indexed as ``(batch, key_position)``.
                Entries equal to 1 are attended to; entries equal to 0
                receive a large negative bias before the softmax. ``None``
                (the default) disables masking.

        Returns:
            Tensor of shape ``(batch, maxlen, hidden_size)``.
        """
        batch_size = tf.shape(query)[0]

        # shape: (batch, maxlen, hidden_size)
        query = self.WQ(query)
        key = self.WK(key)
        value = self.WV(value)

        # shape: (batch, num_heads, maxlen, head_size)
        query = self._split_heads(query, batch_size)
        key = self._split_heads(key, batch_size)
        value = self._split_heads(value, batch_size)

        # shape: (batch, num_heads, maxlen, maxlen)
        matmul_qk = tf.matmul(query, key, transpose_b=True)

        # Scale matmul_qk by sqrt(head_size).
        dk = tf.cast(query.shape[-1], tf.float32)
        score = matmul_qk / tf.math.sqrt(dk)

        if mask is not None:
            # Insert head and query axes so the mask broadcasts over them.
            mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.float32)
            # Masked positions get a large negative score, so their softmax
            # weight is effectively zero.
            score += (1 - mask) * -1e9

        # Normalise over the key axis (the last one).
        alpha = tf.nn.softmax(score, axis=-1)
        context = tf.matmul(alpha, value)

        # Merge the heads back: (batch, maxlen, hidden_size)
        context = tf.transpose(context, perm=[0, 2, 1, 3])
        context = tf.reshape(context, (batch_size, -1, self.hidden_size))
        output = self.dense(context)
        return output
FeedForwardNetwork
class GELU(layers.Layer):
def __init__(self):
super(GELU, self).__init__()
def call(self, x):
cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2