这次作业难度还是挺大的,足足写了1天。主要任务是将人类书写的各种格式的日期转换成 yyyy-mm-dd 的格式输出。
首先导入数据,然后限定人类输入的最大长度为30、机器输出的长度为10,并对x、y进行独热编码。
import keras
import nmt_utils
import os
import numpy as np
# Suppress TensorFlow warnings.
# NOTE(review): this is usually set before TensorFlow/Keras is imported;
# confirm it still takes effect when placed after `import keras`.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Load 10000 samples together with the input/output vocabularies.
dataset, human, machine, inv_machine = nmt_utils.load_dataset(10000)
for item in (dataset[:10], human, machine, inv_machine):
    print(item)

# tx: maximum length of the human-written input; ty: length of the machine output.
tx, ty = 30, 10
timestep = tx
n_a = 32

X, Y, Xoh, Yoh = nmt_utils.preprocess_data(dataset, human, machine, tx, ty)
for label, array in (
    ("X.shape:", X),
    ("Y.shape:", Y),
    ("Xoh.shape:", Xoh),
    ("Yoh.shape:", Yoh),
):
    print(label, array.shape)
观察打印信息,对比x和y,了解输入输出格式以及编码方法。
[('9 may 1998', '1998-05-09'), ('10.11.19', '2019-11-10'), ('9/10/70', '1970-09-10'), ('saturday april 28 1990', '1990-04-28'), ('thursday january 26 1995', '1995-01-26'), ('monday march 7 1983', '1983-03-07'), ('sunday may 22 1988', '1988-05-22'), ('08 jul 2008', '2008-07-08'), ('8 sep 1999', '1999-09-08'), ('thursday january 1 1981', '1981-01-01')]
{' ': 0, '.': 1, '/': 2, '0': 3, '1': 4, '2': 5, '3': 6, '4': 7, '5': 8, '6': 9, '7': 10, '8': 11, '9': 12, 'a': 13, 'b': 14, 'c': 15, 'd': 16, 'e': 17, 'f': 18, 'g': 19, 'h': 20, 'i': 21, 'j': 22, 'l': 23, 'm': 24, 'n': 25, 'o': 26, 'p': 27, 'r': 28, 's': 29, 't': 30, 'u': 31, 'v': 32, 'w': 33, 'y': 34, '<unk>': 35, '<pad>': 36}
{'-': 0, '0': 1, '1': 2, '2': 3, '3': 4, '4': 5, '5': 6, '6': 7, '7': 8, '8': 9, '9': 10}
{0: '-', 1: '0', 2: '1', 3: '2', 4: '3', 5: '4', 6: '5', 7: '6', 8: '7', 9: '8', 10: '9'}
下面就是建模了,计算图如下
注意力模型的示意图:
根据模型,首先搭建最底层的双向循环网络:
def pre_attention(x, n_a):
    """Encode the input sequence with a bidirectional LSTM.

    Args:
        x: Keras tensor holding the input sequence.
        n_a: Number of units of the LSTM in each direction.

    Returns:
        The sequence of per-time-step outputs of the bidirectional LSTM,
        laid out as [batch, timesteps, feature].
    """
    encoder = keras.layers.Bidirectional(
        keras.layers.LSTM(n_a, return_sequences=True)
    )
    return encoder(x)
然后搭建注意力模型以及上层LSTM。
这里要注意:上层LSTM的每个节点都要把上一步输出的S<t-1>和下层输出的a<t>结合起来生成α(注意力权重),并用for循环依次调用每个节点。要实现权重共享,就必须把这些层定义为全局变量(只实例化一次)。我一开始没有设置全局变量,结果每次循环都会新建一个节点,变成了类似CNN那样,产生了300多万的权重。
# Shared weights: every layer below is instantiated exactly once at module
# level so that all decoder steps reuse the same parameters. Creating them
# inside the per-step function would allocate a fresh set of weights per step.
repeat = keras.layers.RepeatVector(timestep)
concatenator = keras.layers.Concatenate()
dence_tanh = keras.layers.Dense(10, activation="tanh")
dence_relu = keras.layers.Dense(1, activation="relu")
# The attention scores have shape (batch, timestep, 1). A softmax over the
# default last axis (size 1) would always return 1.0, i.e. no attention at all,
# so the normalization must run over the time axis (axis=1).
# Models trained with the previous last-axis softmax need to be retrained.
act_softmax = keras.layers.Softmax(axis=1)
dot = keras.layers.Dot(axes=1)
# return_state=True makes the layer return (output, state_h, state_c).
Lstm2 = keras.layers.LSTM(n_a * 2, return_state=True)
def post_attention(s_pre, c_pre, a, n_a, timestep):
    """Run one decoder step: build a context from `a` and advance the LSTM.

    Uses the module-level shared layers (`repeat`, `concatenator`,
    `dence_tanh`, `dence_relu`, `act_softmax`, `dot`, `Lstm2`) so that every
    step reuses the same weights.

    Args:
        s_pre: Hidden state produced by the previous decoder step.
        c_pre: Cell state produced by the previous decoder step.
        a: Output sequence of the pre-attention encoder.
        n_a: Unused here; kept for interface compatibility.
        timestep: Unused here; kept for interface compatibility.

    Returns:
        A tuple `(s, c)`: the LSTM output and cell state for this step.
    """
    # Copy the previous hidden state along the time axis so it can be
    # concatenated with every encoder output.
    repeated_state = repeat(s_pre)
    scores = dence_relu(dence_tanh(concatenator([a, repeated_state])))
    # NOTE(review): for `weights` to be a distribution over input positions,
    # `act_softmax` has to normalize over axis 1 — confirm its definition.
    weights = act_softmax(scores)
    context = dot([weights, a])
    # The layer returns (output, state_h, state_c); state_h is discarded.
    next_s, _, next_c = Lstm2(context, initial_state=[s_pre, c_pre])
    return next_s, next_c
# Shared output layer: one softmax unit per class of the one-hot targets.
cal_output = keras.layers.Dense(units=Yoh.shape[2], activation="softmax")
def model(x, y, n_a=32):
    """Build the attention model with one output head per output position.

    Args:
        x: One-hot encoded inputs; only `x.shape[1]` and `x.shape[2]` are read.
        y: One-hot encoded targets; only `y.shape[1]` (number of decoder
            steps) is read.
        n_a: Number of units of the encoder LSTM in each direction. The
            initial-state inputs have width `2 * n_a`.

    Returns:
        A `keras.Model` taking `[inputs, s0, c0]` and returning a list with
        one prediction tensor per decoder step.
    """
    x_input = keras.Input(shape=(x.shape[1], x.shape[2]))
    a = pre_attention(x_input, n_a)

    state_shape = (2 * n_a,)
    s0 = keras.layers.Input(shape=state_shape)
    c0 = keras.layers.Input(shape=state_shape)

    s, c = s0, c0
    outputs = []
    # `timestep` and `cal_output` are module-level names shared by all steps.
    for _ in range(y.shape[1]):
        s, c = post_attention(s, c, a, n_a, timestep)
        outputs.append(cal_output(s))

    return keras.Model([x_input, s0, c0], outputs)
运行模型,并保存:
# Build the network under its own name so the `model` builder function is not
# overwritten by its result and can still be called afterwards.
attention_model = model(Xoh, Yoh, n_a)
attention_model.summary()
attention_model.compile(
    optimizer=keras.optimizers.Adam(0.005, decay=0.01),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# All-zero initial hidden/cell states, one row per training sample.
s0 = np.zeros((Xoh.shape[0], n_a * 2))
c0 = np.zeros((Xoh.shape[0], n_a * 2))

# swapaxes(0, 1) swaps the first two axes of Yoh so that list() yields one
# target array per model output.
attention_model.fit(
    [Xoh, s0, c0],
    list(Yoh.swapaxes(0, 1)),
    batch_size=1024,
    epochs=150,
)
attention_model.save("week3_model.h5")
attention_model.save_weights("week3_weight.h5")
运行过程:可以看到参数数量非常少
==================================================================================================
Total params: 52,960
Trainable params: 52,960
Non-trainable params: 0
__________________________________________________________________________________________________
Epoch 1/150
1024/10000 [==>...........................] - ETA: 59s - loss: 24.5130 - dense_3_loss: 2.4198 - dense_3_acc: 0.2979 - dense_3_acc_1: 0.0020 - dense_3_acc_2: 0.0947 - dense_3_acc_3: 0.0889 - dense_3_acc_4: 0.0000e+00 - dense_3_acc_5: 0.0459 - dense_3_acc_6: 0.0850 - dense_3_acc_7: 0.0000e+00 - dense_3_acc_8: 0.1777 - dense_3_acc_9: 0.0801
2048/10000 [=====>........................] - ETA: 26s - loss: 23.2616 - dense_3_loss: 2.5393 - dense_3_acc: 0.1763 - dense_3_acc_1: 0.2271 - dense_3_acc_2: 0.1426 - dense_3_acc_3: 0.1128 - dense_3_acc_4: 0.0029 - dense_3_acc_5: 0.3960 - dense_3_acc_6: 0.0918 - dense_3_acc_7: 0.0151 - dense_3_acc_8: 0.2266 - dense_3_acc_9: 0.0957
3072/10000 [========>.....................] - ETA: 15s - loss: 22.6532 - dense_3_loss: 2.6770 - dense_3_acc: 0.1312 - dense_3_acc_1: 0.2731 - dense_3_acc_2: 0.1276 - dense_3_acc_3: 0.0801 - dense_3_acc_4: 0.3216 - dense_3_acc_5: 0.2728 - dense_3_acc_6: 0.0628 - dense_3_acc_7: 0.3376 - dense_3_acc_8: 0.1543 - dense_3_acc_9: 0.0638
4096/10000 [===========>..................] - ETA: 10s - loss: 22.3197 - dense_3_loss: 2.7610 - dense_3_acc: 0.1082 - dense_3_acc_1: 0.3008 - dense_3_acc_2: 0.1350 - dense_3_acc_3: 0.0867 - dense_3_acc_4: 0.3525 - dense_3_acc_5: 0.3008 - dense_3_acc_6: 0.0535 - dense_3_acc_7: 0.4016 - dense_3_acc_8: 0.1450 - dense_3_acc_9: 0.0581
5120/10000 [==============>...............] - ETA: 6s - loss: 22.1142 - dense_3_loss: 2.8040 - dense_3_acc: 0.0896 - dense_3_acc_1: 0.3158 - dense_3_acc_2: 0.1428 - dense_3_acc_3: 0.0840 - dense_3_acc_4: 0.4156 - dense_3_acc_5: 0.2656 - dense_3_acc_6: 0.0453 - dense_3_acc_7: 0.5109 - dense_3_acc_8: 0.1191 - dense_3_acc_9: 0.0475
6144/10000 [=================>............] - ETA: 4s - loss: 21.9315 - dense_3_loss: 2.8214 - dense_3_acc: 0.0760 - dense_3_acc_1: 0.3332 - dense_3_acc_2: 0.1510 - dense_3_acc_3: 0.0807 - dense_3_acc_4: 0.4600 - dense_3_acc_5: 0.2524 - dense_3_acc_6: 0.0399 - dense_3_acc_7: 0.5679 - dense_3_acc_8: 0.1069 - dense_3_acc_9: 0.0422
7168/10000 [====================>.........] - ETA: 2s - loss: 21.7856 - dense_3_loss: 2.8308 - dense_3_acc: 0.0713 - dense_3_acc_1: 0.3467 - dense_3_acc_2: 0.1565 - dense_3_acc_3: 0.0822 - dense_3_acc_4: 0.4824 - dense_3_acc_5: 0.2467 - dense_3_acc_6: 0.0361 - dense_3_acc_7: 0.6081 - dense_3_acc_8: 0.0975 - dense_3_acc_9: 0.0389
8192/10000 [=======================>......] - ETA: 1s - loss: 21.6628 - dense_3_loss: 2.8321 - dense_3_acc: 0.0820 - dense_3_acc_1: 0.3550 - dense_3_acc_2: 0.1622 - dense_3_acc_3: 0.0801 - dense_3_acc_4: 0.5184 - dense_3_acc_5: 0.2314 - dense_3_acc_6: 0.0326 - dense_3_acc_7: 0.6469 - dense_3_acc_8: 0.0887 - dense_3_acc_9: 0.0354
9216/10000 [==========================>...] - ETA: 0s - loss: 21.5547 - dense_3_loss: 2.8269 - dense_3_acc: 0.1066 - dense_3_acc_1: 0.3596 - dense_3_acc_2: 0.1673 - dense_3_acc_3: 0.0814 - dense_3_acc_4: 0.5276 - dense_3_acc_5: 0.2266 - dense_3_acc_6: 0.0335 - dense_3_acc_7: 0.6656 - dense_3_acc_8: 0.0863 - dense_3_acc_9: 0.0352
10000/10000 [==============================] - 8s 769us/step - loss: 21.4706 - dense_3_loss: 2.8235 - dense_3_acc: 0.1183 - dense_3_acc_1: 0.3710 - dense_3_acc_2: 0.1739 - dense_3_acc_3: 0.0880 - dense_3_acc_4: 0.5086 - dense_3_acc_5: 0.2389 - dense_3_acc_6: 0.0361 - dense_3_acc_7: 0.6608 - dense_3_acc_8: 0.0906 - dense_3_acc_9: 0.0373
......
Epoch 150/150
1024/10000 [==>...........................] - ETA: 0s - loss: 0.3782 - dense_3_loss: 0.1019 - dense_3_acc: 0.9990 - dense_3_acc_1: 1.0000 - dense_3_acc_2: 0.9990 - dense_3_acc_3: 0.9961 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9912 - dense_3_acc_6: 0.9688 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9902 - dense_3_acc_9: 0.9805
2048/10000 [=====>........................] - ETA: 0s - loss: 0.3740 - dense_3_loss: 0.0982 - dense_3_acc: 0.9995 - dense_3_acc_1: 1.0000 - dense_3_acc_2: 0.9966 - dense_3_acc_3: 0.9971 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9907 - dense_3_acc_6: 0.9702 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9927 - dense_3_acc_9: 0.9824
3072/10000 [========>.....................] - ETA: 0s - loss: 0.3793 - dense_3_loss: 0.1045 - dense_3_acc: 0.9990 - dense_3_acc_1: 0.9997 - dense_3_acc_2: 0.9958 - dense_3_acc_3: 0.9977 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9909 - dense_3_acc_6: 0.9717 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9909 - dense_3_acc_9: 0.9775
4096/10000 [===========>..................] - ETA: 0s - loss: 0.3723 - dense_3_loss: 0.1016 - dense_3_acc: 0.9993 - dense_3_acc_1: 0.9998 - dense_3_acc_2: 0.9963 - dense_3_acc_3: 0.9978 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9919 - dense_3_acc_6: 0.9717 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9919 - dense_3_acc_9: 0.9771
5120/10000 [==============>...............] - ETA: 0s - loss: 0.3729 - dense_3_loss: 0.1022 - dense_3_acc: 0.9994 - dense_3_acc_1: 0.9998 - dense_3_acc_2: 0.9967 - dense_3_acc_3: 0.9975 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9924 - dense_3_acc_6: 0.9709 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9930 - dense_3_acc_9: 0.9766
6144/10000 [=================>............] - ETA: 0s - loss: 0.3841 - dense_3_loss: 0.1050 - dense_3_acc: 0.9992 - dense_3_acc_1: 0.9997 - dense_3_acc_2: 0.9964 - dense_3_acc_3: 0.9967 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9932 - dense_3_acc_6: 0.9701 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9932 - dense_3_acc_9: 0.9753
7168/10000 [====================>.........] - ETA: 0s - loss: 0.3887 - dense_3_loss: 0.1053 - dense_3_acc: 0.9993 - dense_3_acc_1: 0.9997 - dense_3_acc_2: 0.9964 - dense_3_acc_3: 0.9965 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9926 - dense_3_acc_6: 0.9701 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9925 - dense_3_acc_9: 0.9750
8192/10000 [=======================>......] - ETA: 0s - loss: 0.3896 - dense_3_loss: 0.1053 - dense_3_acc: 0.9994 - dense_3_acc_1: 0.9998 - dense_3_acc_2: 0.9967 - dense_3_acc_3: 0.9967 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9927 - dense_3_acc_6: 0.9694 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9922 - dense_3_acc_9: 0.9752
9216/10000 [==========================>...] - ETA: 0s - loss: 0.3870 - dense_3_loss: 0.1053 - dense_3_acc: 0.9995 - dense_3_acc_1: 0.9998 - dense_3_acc_2: 0.9966 - dense_3_acc_3: 0.9969 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9925 - dense_3_acc_6: 0.9696 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9925 - dense_3_acc_9: 0.9752C:\Users\admin\Anaconda3\envs\tf\lib\site-packages\keras\engine\topology.py:2364:
导入模型,测试自己的数据:
import keras
import nmt_utils
import os
import numpy as np
import matplotlib.pyplot as plt
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress warnings

dataset, human, machine, inv_machine = nmt_utils.load_dataset(10000)
tx = 30
ty = 10  # tx: maximum length of the human input; ty: length of the machine output
timestep = tx
n_a = 32
# Not needed for prediction itself; kept so X/Y/Xoh/Yoh stay available to any
# code that follows.
X, Y, Xoh, Yoh = nmt_utils.preprocess_data(dataset, human, machine, tx, ty)

model = keras.models.load_model("week3_model.h5")
model.load_weights("week3_weight.h5")
# model.summary()

EXAMPLES = ['3 May 1979', '5 April 09', '21th of Aug 2016', 'Tue 10 Jul 2007', 'Saturday May 9 2018', 'Mrch 3 2001',
            'March 3rd 2001', '1 March 2001', '1.25.1997', 'April.26.1990']

# Each prediction is a batch of exactly one sample, so the zero initial states
# need a single row. They are built once instead of allocating a
# (number of training samples, 2 * n_a) array on every iteration.
s0 = np.zeros((1, n_a * 2))
c0 = np.zeros((1, n_a * 2))

for example in EXAMPLES:
    source = nmt_utils.string_to_int(example, tx, human)
    # One-hot encode every character index, then add the batch axis.
    source = np.array([nmt_utils.to_categorical(x, num_classes=len(human)) for x in source])
    source = np.expand_dims(source, axis=0)
    prediction = model.predict([source, s0, c0])
    # The model returns one distribution per output position; take the most
    # likely class of each and map it back to a character.
    prediction = np.argmax(prediction, axis=-1)
    output = [inv_machine[int(i)] for i in prediction]
    print("source:", example, " output:", ''.join(output))

attention_map = nmt_utils.plot_attention_map(model, human, inv_machine, '3 May 1979', 64)
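输出:可以看出模型对大部分格式的输入都能较好地识别,但仍有个别错误(例如 '5 April 09' 的月份被识别成05,'21th of Aug 2016' 和 'Tue 10 Jul 2007' 的日期被输出为00,'1.25.1997' 的年份和月份也不正确)。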
source: 3 May 1979 output: 1979-05-03
source: 5 April 09 output: 2009-05-05
source: 21th of Aug 2016 output: 2016-08-00
source: Tue 10 Jul 2007 output: 2007-07-00
source: Saturday May 9 2018 output: 2018-05-09
source: Mrch 3 2001 output: 2001-03-03
source: March 3rd 2001 output: 2001-03-03
source: 1 March 2001 output: 2001-03-01
source: 1.25.1997 output: 1999-10-25
source: April.26.1990 output: 1990-04-26