1. Overview
- The attention mechanism mirrors how each of us weights features by importance when taking in data. When a salesperson introduces a new phone, for example, male and female customers listen for different highlights: a male customer may focus on "6GB+256GB" and the latest Snapdragon processor, while a female customer may notice the bright cherry color or the new pink model.
2. Explanation
This kind of attention is really how the human brain and eyes filter the massive amount of data in the real world, keeping only the points each person cares about. We can therefore treat a person as an attention model: large amounts of data pass through it, and it outputs the salient features.
3. Applications
- Computer vision: we start with visual attention, because it shows very intuitively how a model attends to the important content.
A paper illustrates this with a generated image of a seagull (the figure is not reproduced here): when the seagull is flying, the model attends first and foremost to the wings, and only then to the flight itself.
- Natural language processing: attention can be used in NLP as well; for the same passage, different attention models extract different key features.
We repeatedly examine both the global and the local information of a passage, fuse the two together, and train the attention model against the labels.
A classic attention model of this kind is the one from the Luong et al. paper (linked at the top of the code below). When the decoder is predicting at the word “爱” (love), it takes the decoder state at “爱” and combines it with the information of every word in the encoder to produce attention weights; the model might, for example, attend more to “爱” and “莫烦” (Morvan) on the encoder side. The weights are then applied back onto each encoder state, so different state steps receive different degrees of focus, and their weighted sum is the attention result, the context. Finally, the context is combined with the decoder-side information to produce the attended answer.
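To make that weighting step concrete, here is a minimal NumPy sketch of a single decoder step (toy shapes and a dot-product score; `enc_states` and `dec_state` stand in for the LSTM states described above):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

enc_states = np.random.randn(4, 8)   # states of 4 encoder steps, size 8 each
dec_state = np.random.randn(8)       # decoder state at the current step

scores = enc_states @ dec_state      # one score per encoder step
weights = softmax(scores)            # attention weights over encoder steps
context = weights @ enc_states       # weighted sum of encoder states
print(weights.sum(), context.shape)  # ~1.0 (8,)
```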
The paper computes the attention score as follows:
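In the notation of Luong et al. (2015), for decoder state $h_t$ and encoder state $\bar h_s$, the alignment weight is a softmax over scores:

$$
a_t(s) = \mathrm{align}(h_t, \bar h_s) = \frac{\exp\big(\mathrm{score}(h_t, \bar h_s)\big)}{\sum_{s'} \exp\big(\mathrm{score}(h_t, \bar h_{s'})\big)}
$$

with three scoring variants (the `tfa.seq2seq.LuongAttention` used in the code below corresponds to the *general* form):

$$
\mathrm{score}(h_t, \bar h_s) =
\begin{cases}
h_t^\top \bar h_s & \text{dot} \\
h_t^\top W_a \bar h_s & \text{general} \\
v_a^\top \tanh\big(W_a [h_t; \bar h_s]\big) & \text{concat}
\end{cases}
$$

The context vector is then $c_t = \sum_s a_t(s)\,\bar h_s$, and the attended output combines it with the decoder state: $\tilde h_t = \tanh(W_c [c_t; h_t])$.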
A hands-on project (translation):
In this project we use an attention model to implement a date-translation example.
For the translation example we need to build an Encoder and a Decoder; the attention model means that while the Decoder generates language, it should also attend to the corresponding parts of the Encoder.
```
# Chinese "year-month-day" -> "day/month/year"
"98-02-26" -> "26/Feb/1998"
```
This example converts Chinese-style dates into the English format. We will also plot an attention map, so we can see which part of the input the model relies on when generating each output character.
In the attention map, the x-axis is the Chinese year-month-day input and the y-axis is the generated English day/month/year output.
Code:
- generate the data;
- build the model;
- train the model.
```python
# [Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/pdf/1508.04025.pdf)
import tensorflow as tf
from tensorflow import keras
import numpy as np
import utils  # this refers to utils.py in my [repo](https://github.com/MorvanZhou/NLP-Tutorials/)
import tensorflow_addons as tfa
import pickle

class Seq2Seq(keras.Model):
    def __init__(self, enc_v_dim, dec_v_dim, emb_dim, units, attention_layer_size, max_pred_len, start_token, end_token):
        super().__init__()
        self.units = units

        # encoder
        self.enc_embeddings = keras.layers.Embedding(
            input_dim=enc_v_dim, output_dim=emb_dim,  # [enc_n_vocab, emb_dim]
            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
        )
        self.encoder = keras.layers.LSTM(units=units, return_sequences=True, return_state=True)

        # decoder
        self.attention = tfa.seq2seq.LuongAttention(units, memory=None, memory_sequence_length=None)
        self.decoder_cell = tfa.seq2seq.AttentionWrapper(
            cell=keras.layers.LSTMCell(units=units),
            attention_mechanism=self.attention,
            attention_layer_size=attention_layer_size,
            alignment_history=True,  # for attention visualization
        )
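        # note: the attention is created with memory=None; the real encoder
        # outputs are supplied per batch through self.attention.setup_memory()
        # in set_attention() below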
        self.dec_embeddings = keras.layers.Embedding(
            input_dim=dec_v_dim, output_dim=emb_dim,  # [dec_n_vocab, emb_dim]
            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
        )
        decoder_dense = keras.layers.Dense(dec_v_dim)  # output layer

        # train decoder
        self.decoder_train = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            sampler=tfa.seq2seq.sampler.TrainingSampler(),  # sampler for train
            output_layer=decoder_dense
        )
        self.cross_entropy = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.opt = keras.optimizers.Adam(0.05, clipnorm=5.0)

        # predict decoder
        self.decoder_eval = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            sampler=tfa.seq2seq.sampler.GreedyEmbeddingSampler(),  # sampler for predict
            output_layer=decoder_dense
        )

        # prediction restriction
        self.max_pred_len = max_pred_len
        self.start_token = start_token
        self.end_token = end_token

    def encode(self, x):
        o = self.enc_embeddings(x)
        init_s = [tf.zeros((x.shape[0], self.units)), tf.zeros((x.shape[0], self.units))]
        o, h, c = self.encoder(o, initial_state=init_s)
        return o, h, c

    def set_attention(self, x):
        o, h, c = self.encode(x)
        # encoder output for attention to focus
        self.attention.setup_memory(o)
        # wrap state by attention wrapper
        s = self.decoder_cell.get_initial_state(batch_size=x.shape[0], dtype=tf.float32).clone(cell_state=[h, c])
        return s

    def inference(self, x, return_align=False):
        s = self.set_attention(x)
        done, i, s = self.decoder_eval.initialize(
            self.dec_embeddings.variables[0],
            start_tokens=tf.fill([x.shape[0], ], self.start_token),
            end_token=self.end_token,
            initial_state=s,
        )
        pred_id = np.zeros((x.shape[0], self.max_pred_len), dtype=np.int32)
        for l in range(self.max_pred_len):
            o, s, i, done = self.decoder_eval.step(
                time=l, inputs=i, state=s, training=False)
            pred_id[:, l] = o.sample_id
        if return_align:
            return np.transpose(s.alignment_history.stack().numpy(), (1, 0, 2))
        else:
            s.alignment_history.mark_used()  # otherwise gives warning
            return pred_id

    def train_logits(self, x, y, seq_len):
        s = self.set_attention(x)
        dec_in = y[:, :-1]  # ignore <EOS>
        dec_emb_in = self.dec_embeddings(dec_in)
        o, _, _ = self.decoder_train(dec_emb_in, s, sequence_length=seq_len)
        logits = o.rnn_output
        return logits
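    # train_logits() decodes with the TrainingSampler, i.e. teacher forcing:
    # training feeds the ground-truth previous token at every step, while
    # inference() feeds back the model's own greedy predictions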
    def step(self, x, y, seq_len):
        with tf.GradientTape() as tape:
            logits = self.train_logits(x, y, seq_len)
            dec_out = y[:, 1:]  # ignore <GO>
            loss = self.cross_entropy(dec_out, logits)
        grads = tape.gradient(loss, self.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.trainable_variables))
        return loss.numpy()


def train():
    # get and process data
    data = utils.DateData(2000)
    print("Chinese time order: yy/mm/dd ", data.date_cn[:3], "\nEnglish time order: dd/M/yyyy ", data.date_en[:3])
    print("vocabularies: ", data.vocab)
    print("x index sample: \n{}\n{}".format(data.idx2str(data.x[0]), data.x[0]),
          "\ny index sample: \n{}\n{}".format(data.idx2str(data.y[0]), data.y[0]))

    model = Seq2Seq(
        data.num_word, data.num_word, emb_dim=12, units=14, attention_layer_size=16,
        max_pred_len=11, start_token=data.start_token, end_token=data.end_token)

    # training
    for t in range(1000):
        bx, by, decoder_len = data.sample(64)
        loss = model.step(bx, by, decoder_len)
        if t % 70 == 0:
            target = data.idx2str(by[0, 1:-1])
            pred = model.inference(bx[0:1])
            res = data.idx2str(pred[0])
            src = data.idx2str(bx[0])
            print(
                "t: ", t,
                "| loss: %.5f" % loss,
                "| input: ", src,
                "| target: ", target,
                "| inference: ", res,
            )

    pkl_data = {"i2v": data.i2v, "x": data.x[:6], "y": data.y[:6], "align": model.inference(data.x[:6], return_align=True)}

    with open("./visual/tmp/attention_align.pkl", "wb") as f:
        pickle.dump(pkl_data, f)


if __name__ == "__main__":
    train()
```
Utilities (utils.py):
```python
import numpy as np
import datetime
import os
import requests
import pandas as pd
import re
import itertools

PAD_ID = 0


class DateData:
    def __init__(self, n):
        np.random.seed(1)
        self.date_cn = []
        self.date_en = []
        for timestamp in np.random.randint(143835585, 2043835585, n):
            date = datetime.datetime.fromtimestamp(timestamp)
            self.date_cn.append(date.strftime("%y-%m-%d"))
            self.date_en.append(date.strftime("%d/%b/%Y"))
        self.vocab = set(
            [str(i) for i in range(0, 10)] + ["-", "/", "<GO>", "<EOS>"] + [
                i.split("/")[1] for i in self.date_en])
        self.v2i = {v: i for i, v in enumerate(sorted(list(self.vocab)), start=1)}
        self.v2i["<PAD>"] = PAD_ID
        self.vocab.add("<PAD>")
        self.i2v = {i: v for v, i in self.v2i.items()}
        self.x, self.y = [], []
        for cn, en in zip(self.date_cn, self.date_en):
            self.x.append([self.v2i[v] for v in cn])
            self.y.append(
                [self.v2i["<GO>"], ] + [self.v2i[v] for v in en[:3]] + [
                    self.v2i[en[3:6]], ] + [self.v2i[v] for v in en[6:]] + [
                    self.v2i["<EOS>"], ])
        self.x, self.y = np.array(self.x), np.array(self.y)
        self.start_token = self.v2i["<GO>"]
        self.end_token = self.v2i["<EOS>"]

    def sample(self, n=64):
        bi = np.random.randint(0, len(self.x), size=n)
        bx, by = self.x[bi], self.y[bi]
        decoder_len = np.full((len(bx),), by.shape[1] - 1, dtype=np.int32)
        return bx, by, decoder_len

    def idx2str(self, idx):
        x = []
        for i in idx:
            x.append(self.i2v[i])
            if i == self.end_token:
                break
        return "".join(x)

    @property
    def num_word(self):
        return len(self.vocab)


def pad_zero(seqs, max_len):
    padded = np.full((len(seqs), max_len), fill_value=PAD_ID, dtype=np.int32)
    for i, seq in enumerate(seqs):
        padded[i, :len(seq)] = seq
    return padded


def maybe_download_mrpc(save_dir="./MRPC/", proxy=None):
    train_url = 'https://mofanpy.com/static/files/MRPC/msr_paraphrase_train.txt'
    test_url = 'https://mofanpy.com/static/files/MRPC/msr_paraphrase_test.txt'
    os.makedirs(save_dir, exist_ok=True)
    proxies = {"http": proxy, "https": proxy}
    for url in [train_url, test_url]:
        raw_path = os.path.join(save_dir, url.split("/")[-1])
        if not os.path.isfile(raw_path):
            print("downloading from %s" % url)
            r = requests.get(url, proxies=proxies)
            with open(raw_path, "w", encoding="utf-8") as f:
                f.write(r.text.replace('"', "<QUOTE>"))
    print("completed")


def _text_standardize(text):
    text = re.sub(r'—', '-', text)
    text = re.sub(r'–', '-', text)
    text = re.sub(r'―', '-', text)
    text = re.sub(r" \d+(,\d+)?(\.\d+)? ", " <NUM> ", text)
    text = re.sub(r" \d+-+?\d*", " <NUM>-", text)
    return text.strip()


def _process_mrpc(dir="./MRPC", rows=None):
    data = {"train": None, "test": None}
    files = os.listdir(dir)
    for f in files:
        df = pd.read_csv(os.path.join(dir, f), sep='\t', nrows=rows)
        k = "train" if "train" in f else "test"
        data[k] = {"is_same": df.iloc[:, 0].values, "s1": df["#1 String"].values, "s2": df["#2 String"].values}
    vocab = set()
    for n in ["train", "test"]:
        for m in ["s1", "s2"]:
            for i in range(len(data[n][m])):
                data[n][m][i] = _text_standardize(data[n][m][i].lower())
                cs = data[n][m][i].split(" ")
                vocab.update(set(cs))
    v2i = {v: i for i, v in enumerate(sorted(vocab), start=1)}
    v2i["<PAD>"] = PAD_ID
    v2i["<MASK>"] = len(v2i)
    v2i["<SEP>"] = len(v2i)
    v2i["<GO>"] = len(v2i)
    i2v = {i: v for v, i in v2i.items()}
    for n in ["train", "test"]:
        for m in ["s1", "s2"]:
            data[n][m + "id"] = [[v2i[v] for v in c.split(" ")] for c in data[n][m]]
    return data, v2i, i2v


class MRPCData:
    num_seg = 3
    pad_id = PAD_ID

    def __init__(self, data_dir="./MRPC/", rows=None, proxy=None):
        maybe_download_mrpc(save_dir=data_dir, proxy=proxy)
        data, self.v2i, self.i2v = _process_mrpc(data_dir, rows)
        self.max_len = max(
            [len(s1) + len(s2) + 3 for s1, s2 in zip(
                data["train"]["s1id"] + data["test"]["s1id"], data["train"]["s2id"] + data["test"]["s2id"])])

        self.xlen = np.array([
            [
                len(data["train"]["s1id"][i]), len(data["train"]["s2id"][i])
            ] for i in range(len(data["train"]["s1id"]))], dtype=int)
        x = [
            [self.v2i["<GO>"]] + data["train"]["s1id"][i] + [self.v2i["<SEP>"]] + data["train"]["s2id"][i] + [self.v2i["<SEP>"]]
            for i in range(len(self.xlen))
        ]
        self.x = pad_zero(x, max_len=self.max_len)
        self.nsp_y = data["train"]["is_same"][:, None]

        self.seg = np.full(self.x.shape, self.num_seg - 1, np.int32)
        for i in range(len(x)):
            si = self.xlen[i][0] + 2
            self.seg[i, :si] = 0
            si_ = si + self.xlen[i][1] + 1
            self.seg[i, si:si_] = 1

        self.word_ids = np.array(list(set(self.i2v.keys()).difference(
            [self.v2i[v] for v in ["<PAD>", "<MASK>", "<SEP>"]])))

    def sample(self, n):
        bi = np.random.randint(0, self.x.shape[0], size=n)
        bx, bs, bl, by = self.x[bi], self.seg[bi], self.xlen[bi], self.nsp_y[bi]
        return bx, bs, bl, by

    @property
    def num_word(self):
        return len(self.v2i)

    @property
    def mask_id(self):
        return self.v2i["<MASK>"]


class MRPCSingle:
    pad_id = PAD_ID

    def __init__(self, data_dir="./MRPC/", rows=None, proxy=None):
        maybe_download_mrpc(save_dir=data_dir, proxy=proxy)
        data, self.v2i, self.i2v = _process_mrpc(data_dir, rows)

        self.max_len = max([len(s) + 2 for s in data["train"]["s1id"] + data["train"]["s2id"]])
        x = [
            [self.v2i["<GO>"]] + data["train"]["s1id"][i] + [self.v2i["<SEP>"]]
            for i in range(len(data["train"]["s1id"]))
        ]
        x += [
            [self.v2i["<GO>"]] + data["train"]["s2id"][i] + [self.v2i["<SEP>"]]
            for i in range(len(data["train"]["s2id"]))
        ]
        self.x = pad_zero(x, max_len=self.max_len)
        self.word_ids = np.array(list(set(self.i2v.keys()).difference([self.v2i["<PAD>"]])))

    def sample(self, n):
        bi = np.random.randint(0, self.x.shape[0], size=n)
        bx = self.x[bi]
        return bx

    @property
    def num_word(self):
        return len(self.v2i)


class Dataset:
    def __init__(self, x, y, v2i, i2v):
        self.x, self.y = x, y
        self.v2i, self.i2v = v2i, i2v
        self.vocab = v2i.keys()

    def sample(self, n):
        b_idx = np.random.randint(0, len(self.x), n)
        bx, by = self.x[b_idx], self.y[b_idx]
        return bx, by

    @property
    def num_word(self):
        return len(self.v2i)


def process_w2v_data(corpus, skip_window=2, method="skip_gram"):
    all_words = [sentence.split(" ") for sentence in corpus]
    all_words = np.array(list(itertools.chain(*all_words)))
    # vocab sort by decreasing frequency for the negative sampling below (nce_loss).
    vocab, v_count = np.unique(all_words, return_counts=True)
    vocab = vocab[np.argsort(v_count)[::-1]]

    print("all vocabularies sorted from more frequent to less frequent:\n", vocab)

    v2i = {v: i for i, v in enumerate(vocab)}
    i2v = {i: v for v, i in v2i.items()}

    # pair data
    pairs = []
    js = [i for i in range(-skip_window, skip_window + 1) if i != 0]

    for c in corpus:
        words = c.split(" ")
        w_idx = [v2i[w] for w in words]
        if method == "skip_gram":
            for i in range(len(w_idx)):
                for j in js:
                    if i + j < 0 or i + j >= len(w_idx):
                        continue
                    pairs.append((w_idx[i], w_idx[i + j]))  # (center, context) or (feature, target)
        elif method.lower() == "cbow":
            for i in range(skip_window, len(w_idx) - skip_window):
                context = []
                for j in js:
                    context.append(w_idx[i + j])
                pairs.append(context + [w_idx[i]])  # (contexts, center) or (feature, target)
        else:
            raise ValueError
    pairs = np.array(pairs)
    print("5 example pairs:\n", pairs[:5])

    if method.lower() == "skip_gram":
        x, y = pairs[:, 0], pairs[:, 1]
    elif method.lower() == "cbow":
        x, y = pairs[:, :-1], pairs[:, -1]
    else:
        raise ValueError
    return Dataset(x, y, v2i, i2v)


def set_soft_gpu(soft_gpu):
    import tensorflow as tf
    if soft_gpu:
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
```
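As a quick sanity check, `DateData` can also be exercised on its own; a minimal sketch (it mirrors what `train()` does above, and the shapes follow from the construction: 8 input tokens per date, 11 target tokens including `<GO>`/`<EOS>`):

```python
from utils import DateData

d = DateData(2000)
print(d.date_cn[0], "->", d.date_en[0])  # a "yy-mm-dd" string and its "dd/Mon/yyyy" form
bx, by, decoder_len = d.sample(4)
print(bx.shape, by.shape, decoder_len)   # (4, 8) (4, 11) [10 10 10 10]
print(d.idx2str(bx[0]), "->", d.idx2str(by[0]))
```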
Training output:
```
t: 0 | loss: 3.29403 | input: 89-05-25 | target: 25/May/1989 | inference: 00000000000
t: 70 | loss: 0.41608 | input: 03-09-13 | target: 13/Sep/2003 | inference: 13/Jan/2000<EOS>
t: 140 | loss: 0.01793 | input: 92-06-01 | target: 01/Jun/1992 | inference: 01/Jun/1992<EOS>
t: 210 | loss: 0.00309 | input: 23-01-28 | target: 28/Jan/2023 | inference: 28/Jan/2023<EOS>
...
t: 910 | loss: 0.00003 | input: 11-09-13 | target: 13/Sep/2011 | inference: 13/Sep/2011<EOS>
t: 980 | loss: 0.00003 | input: 06-08-10 | target: 10/Aug/2006 | inference: 10/Aug/2006<EOS>
```
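The attention map promised earlier can be drawn from the alignments that `train()` pickles to `./visual/tmp/attention_align.pkl`. The original repo ships its own visualization code; here is just a minimal matplotlib sketch of the same idea (the heatmap layout is my assumption):

```python
import pickle
import matplotlib.pyplot as plt

with open("./visual/tmp/attention_align.pkl", "rb") as f:
    d = pickle.load(f)

i = 0                                   # which of the 6 saved examples to show
x, y, align, i2v = d["x"][i], d["y"][i], d["align"][i], d["i2v"]
a = align[:len(y) - 1]                  # one row per target token (drop steps past <EOS>)
plt.imshow(a, cmap="YlGn")
plt.xticks(range(len(x)), [i2v[j] for j in x])          # input: y y - m m - d d
plt.yticks(range(a.shape[0]), [i2v[j] for j in y[1:]])  # target: d d / Mon / y y y y
plt.xlabel("input (Chinese yy-mm-dd)")
plt.ylabel("output (English dd/Mon/yyyy)")
plt.show()
```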