Andrew Ng DeepLearning Part 5, Week 3 Assignment (1): Machine Translation

This assignment is fairly demanding and took me a full day. The task is to translate dates written in all sorts of human-readable formats into the machine format yyyy-mm-dd.

First load the data, cap the human-readable input at 30 characters and the machine output at 10 characters, then one-hot encode X and Y.

import keras
import nmt_utils
import os
import numpy as np

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress TensorFlow info/warning messages

dataset, human, machine, inv_machine = nmt_utils.load_dataset(10000)
print(dataset[:10])
print(human)
print(machine)
print(inv_machine)
tx = 30  # max length of the human-readable input
ty = 10  # length of the machine-readable output
timestep = tx
n_a = 32

X, Y, Xoh, Yoh = nmt_utils.preprocess_data(dataset, human, machine, tx, ty)
print("X.shape:", X.shape)
print("Y.shape:", Y.shape)
print("Xoh.shape:", Xoh.shape)
print("Yoh.shape:", Yoh.shape)

Inspect the printed output and compare X with Y to understand the input/output formats and how they are encoded.

[('9 may 1998', '1998-05-09'), ('10.11.19', '2019-11-10'), ('9/10/70', '1970-09-10'), ('saturday april 28 1990', '1990-04-28'), ('thursday january 26 1995', '1995-01-26'), ('monday march 7 1983', '1983-03-07'), ('sunday may 22 1988', '1988-05-22'), ('08 jul 2008', '2008-07-08'), ('8 sep 1999', '1999-09-08'), ('thursday january 1 1981', '1981-01-01')]

{' ': 0, '.': 1, '/': 2, '0': 3, '1': 4, '2': 5, '3': 6, '4': 7, '5': 8, '6': 9, '7': 10, '8': 11, '9': 12, 'a': 13, 'b': 14, 'c': 15, 'd': 16, 'e': 17, 'f': 18, 'g': 19, 'h': 20, 'i': 21, 'j': 22, 'l': 23, 'm': 24, 'n': 25, 'o': 26, 'p': 27, 'r': 28, 's': 29, 't': 30, 'u': 31, 'v': 32, 'w': 33, 'y': 34, '<unk>': 35, '<pad>': 36}


{'-': 0, '0': 1, '1': 2, '2': 3, '3': 4, '4': 5, '5': 6, '6': 7, '7': 8, '8': 9, '9': 10}
{0: '-', 1: '0', 2: '1', 3: '2', 4: '3', 5: '4', 6: '5', 7: '6', 8: '7', 9: '8', 10: '9'}
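
Given 10000 samples and the two vocabularies above (37 human characters, 11 machine characters), the shape prints should read X: (10000, 30), Y: (10000, 10), Xoh: (10000, 30, 37), Yoh: (10000, 10, 11). As a quick sanity check, you can decode one encoded pair back into text; a small sketch, assuming the variables from the snippet above are still in scope:

inv_human = {v: k for k, v in human.items()}  # invert the char -> index mapping of the human vocabulary
print(''.join(inv_human[int(i)] for i in X[0]))    # '9 may 1998' followed by '<pad>' up to length 30
print(''.join(inv_machine[int(i)] for i in Y[0]))  # '1998-05-09'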

Now for the model. The computation graph is shown below:

Schematic of the attention model:
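
In brief (this matches the code below): a bidirectional LSTM with n_a = 32 units per direction reads the 30 input characters; for each of the 10 output positions, an attention block combines the previous post-attention hidden state s<t-1> with all 30 bidirectional activations a<t'> to produce attention weights α and a context vector; a post-attention LSTM with 64 units consumes that context, and a Dense softmax layer over the 11 output characters produces each y<t>.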

Following the model, first build the bottom bidirectional recurrent layer:

def pre_attention(x, n_a):
    a = keras.layers.Bidirectional(keras.layers.LSTM(n_a, return_sequences=True))(x)  # [batch, timesteps, feature]
    # print("a:",a)
    return a

Then build the attention block and the upper (post-attention) LSTM.

One thing to watch out for: at every output step, the upper LSTM's previous hidden state s<t-1> is combined with all of the lower layer's activations a<t'> to produce the attention weights α, and the 10 output steps are generated with a for loop. To share weights across those steps, the layers must be created once, outside the loop (here as global variables). When I first created them inside the loop, every iteration instantiated a fresh set of layers, a bit like stacking extra CNN layers, and the model ended up with over three million parameters.
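
For reference, the tensor shapes through one attention step (with tx = 30 and n_a = 32, so the bidirectional output a is 64-dimensional):

s_pre: (batch, 64) → RepeatVector(30) → (batch, 30, 64)
concatenate with a (batch, 30, 64) → (batch, 30, 128)
Dense(10, tanh) → (batch, 30, 10), then Dense(1, relu) → (batch, 30, 1)
softmax over the 30 time steps → alpha: (batch, 30, 1)
Dot(axes=1)([alpha, a]) → context: (batch, 1, 64)
LSTM(64) on context → new s and c, each (batch, 64)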

# Shared layers (weight sharing across output time steps) -- this is the key difference from a CNN
repeat = keras.layers.RepeatVector(timestep)
concatenator = keras.layers.Concatenate()
dence_tanh = keras.layers.Dense(10, activation="tanh")
dence_relu = keras.layers.Dense(1, activation="relu")
act_softmax = keras.layers.Activation(nmt_utils.softmax)  # nmt_utils.softmax normalizes over the time axis (axis=1); the string "softmax" would act on the size-1 last axis and return all ones
dot = keras.layers.Dot(axes=1)
Lstm2 = keras.layers.LSTM(n_a * 2, return_state=True)


def post_attention(s_pre, c_pre, a, n_a, timestep):
    s = repeat(s_pre)
    # print("s:", s)
    dense_input = concatenator([a, s])
    # print("dense_input:", dense_input)
    e1 = dence_tanh(dense_input)
    # print("e1:", e1)
    e2 = dence_relu(e1)
    # print("e2:", e2)
    alpha = act_softmax(e2)
    # print("alpha:", alpha)
    context = dot([alpha, a])
    # print("context:", context)
    s, _, c = Lstm2(context, initial_state=[s_pre, c_pre])  # lstm1, state_h, state_c = LSTM(1, return_state=True)
    # print("s:", s)
    # print("c:", c)
    return s, c

cal_output = keras.layers.Dense(Yoh.shape[2], activation="softmax")


def model(x, y, n_a=32):
    x_input = keras.Input(shape=(x.shape[1], x.shape[2]))
    a = pre_attention(x_input, n_a)
    s0 = keras.layers.Input(shape=(2 * n_a,))
    c0 = keras.layers.Input(shape=(2 * n_a,))
    s = s0
    c = c0
    # print("s0:", s)
    # print("c0:", c)
    outputs = []
    for i in range(y.shape[1]):
        s, c = post_attention(s, c, a, n_a, timestep)
        y_hat = cal_output(s)
        outputs.append(y_hat)
    model = keras.Model([x_input, s0, c0], outputs)
    return model

Build, compile, train, and save the model:

model = model(Xoh, Yoh, n_a)
model.summary()
model.compile(optimizer=keras.optimizers.Adam(0.005, decay=0.01), loss="categorical_crossentropy", metrics=["accuracy"])

s0 = np.zeros((Xoh.shape[0], n_a * 2))
c0 = np.zeros((Xoh.shape[0], n_a * 2))
model.fit([Xoh, s0, c0], list(Yoh.swapaxes(0, 1)), batch_size=1024, epochs=150)  # swapaxes turns Yoh into 10 target arrays, one per output step
model.save("week3_model.h5")
model.save_weights("week3_weight.h5")
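
Because the model has 10 separate outputs, Keras expects the targets as a list of 10 arrays; a quick check of what the fit call above passes in (a sketch):

targets = list(Yoh.swapaxes(0, 1))
print(len(targets), targets[0].shape)  # 10 arrays, each of shape (10000, 11)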

Training output. Note how small the parameter count is; a rough breakdown of where the 52,960 parameters come from follows the training log below.

==================================================================================================

Total params: 52,960
Trainable params: 52,960
Non-trainable params: 0
__________________________________________________________________________________________________


Epoch 1/150

 1024/10000 [==>...........................] - ETA: 59s - loss: 24.5130 - dense_3_loss: 2.4198 - dense_3_acc: 0.2979 - dense_3_acc_1: 0.0020 - dense_3_acc_2: 0.0947 - dense_3_acc_3: 0.0889 - dense_3_acc_4: 0.0000e+00 - dense_3_acc_5: 0.0459 - dense_3_acc_6: 0.0850 - dense_3_acc_7: 0.0000e+00 - dense_3_acc_8: 0.1777 - dense_3_acc_9: 0.0801
 2048/10000 [=====>........................] - ETA: 26s - loss: 23.2616 - dense_3_loss: 2.5393 - dense_3_acc: 0.1763 - dense_3_acc_1: 0.2271 - dense_3_acc_2: 0.1426 - dense_3_acc_3: 0.1128 - dense_3_acc_4: 0.0029 - dense_3_acc_5: 0.3960 - dense_3_acc_6: 0.0918 - dense_3_acc_7: 0.0151 - dense_3_acc_8: 0.2266 - dense_3_acc_9: 0.0957        
 3072/10000 [========>.....................] - ETA: 15s - loss: 22.6532 - dense_3_loss: 2.6770 - dense_3_acc: 0.1312 - dense_3_acc_1: 0.2731 - dense_3_acc_2: 0.1276 - dense_3_acc_3: 0.0801 - dense_3_acc_4: 0.3216 - dense_3_acc_5: 0.2728 - dense_3_acc_6: 0.0628 - dense_3_acc_7: 0.3376 - dense_3_acc_8: 0.1543 - dense_3_acc_9: 0.0638
 4096/10000 [===========>..................] - ETA: 10s - loss: 22.3197 - dense_3_loss: 2.7610 - dense_3_acc: 0.1082 - dense_3_acc_1: 0.3008 - dense_3_acc_2: 0.1350 - dense_3_acc_3: 0.0867 - dense_3_acc_4: 0.3525 - dense_3_acc_5: 0.3008 - dense_3_acc_6: 0.0535 - dense_3_acc_7: 0.4016 - dense_3_acc_8: 0.1450 - dense_3_acc_9: 0.0581
 5120/10000 [==============>...............] - ETA: 6s - loss: 22.1142 - dense_3_loss: 2.8040 - dense_3_acc: 0.0896 - dense_3_acc_1: 0.3158 - dense_3_acc_2: 0.1428 - dense_3_acc_3: 0.0840 - dense_3_acc_4: 0.4156 - dense_3_acc_5: 0.2656 - dense_3_acc_6: 0.0453 - dense_3_acc_7: 0.5109 - dense_3_acc_8: 0.1191 - dense_3_acc_9: 0.0475 
 6144/10000 [=================>............] - ETA: 4s - loss: 21.9315 - dense_3_loss: 2.8214 - dense_3_acc: 0.0760 - dense_3_acc_1: 0.3332 - dense_3_acc_2: 0.1510 - dense_3_acc_3: 0.0807 - dense_3_acc_4: 0.4600 - dense_3_acc_5: 0.2524 - dense_3_acc_6: 0.0399 - dense_3_acc_7: 0.5679 - dense_3_acc_8: 0.1069 - dense_3_acc_9: 0.0422
 7168/10000 [====================>.........] - ETA: 2s - loss: 21.7856 - dense_3_loss: 2.8308 - dense_3_acc: 0.0713 - dense_3_acc_1: 0.3467 - dense_3_acc_2: 0.1565 - dense_3_acc_3: 0.0822 - dense_3_acc_4: 0.4824 - dense_3_acc_5: 0.2467 - dense_3_acc_6: 0.0361 - dense_3_acc_7: 0.6081 - dense_3_acc_8: 0.0975 - dense_3_acc_9: 0.0389
 8192/10000 [=======================>......] - ETA: 1s - loss: 21.6628 - dense_3_loss: 2.8321 - dense_3_acc: 0.0820 - dense_3_acc_1: 0.3550 - dense_3_acc_2: 0.1622 - dense_3_acc_3: 0.0801 - dense_3_acc_4: 0.5184 - dense_3_acc_5: 0.2314 - dense_3_acc_6: 0.0326 - dense_3_acc_7: 0.6469 - dense_3_acc_8: 0.0887 - dense_3_acc_9: 0.0354
 9216/10000 [==========================>...] - ETA: 0s - loss: 21.5547 - dense_3_loss: 2.8269 - dense_3_acc: 0.1066 - dense_3_acc_1: 0.3596 - dense_3_acc_2: 0.1673 - dense_3_acc_3: 0.0814 - dense_3_acc_4: 0.5276 - dense_3_acc_5: 0.2266 - dense_3_acc_6: 0.0335 - dense_3_acc_7: 0.6656 - dense_3_acc_8: 0.0863 - dense_3_acc_9: 0.0352
10000/10000 [==============================] - 8s 769us/step - loss: 21.4706 - dense_3_loss: 2.8235 - dense_3_acc: 0.1183 - dense_3_acc_1: 0.3710 - dense_3_acc_2: 0.1739 - dense_3_acc_3: 0.0880 - dense_3_acc_4: 0.5086 - dense_3_acc_5: 0.2389 - dense_3_acc_6: 0.0361 - dense_3_acc_7: 0.6608 - dense_3_acc_8: 0.0906 - dense_3_acc_9: 0.0373


......


Epoch 150/150

 1024/10000 [==>...........................] - ETA: 0s - loss: 0.3782 - dense_3_loss: 0.1019 - dense_3_acc: 0.9990 - dense_3_acc_1: 1.0000 - dense_3_acc_2: 0.9990 - dense_3_acc_3: 0.9961 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9912 - dense_3_acc_6: 0.9688 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9902 - dense_3_acc_9: 0.9805
 2048/10000 [=====>........................] - ETA: 0s - loss: 0.3740 - dense_3_loss: 0.0982 - dense_3_acc: 0.9995 - dense_3_acc_1: 1.0000 - dense_3_acc_2: 0.9966 - dense_3_acc_3: 0.9971 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9907 - dense_3_acc_6: 0.9702 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9927 - dense_3_acc_9: 0.9824
 3072/10000 [========>.....................] - ETA: 0s - loss: 0.3793 - dense_3_loss: 0.1045 - dense_3_acc: 0.9990 - dense_3_acc_1: 0.9997 - dense_3_acc_2: 0.9958 - dense_3_acc_3: 0.9977 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9909 - dense_3_acc_6: 0.9717 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9909 - dense_3_acc_9: 0.9775
 4096/10000 [===========>..................] - ETA: 0s - loss: 0.3723 - dense_3_loss: 0.1016 - dense_3_acc: 0.9993 - dense_3_acc_1: 0.9998 - dense_3_acc_2: 0.9963 - dense_3_acc_3: 0.9978 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9919 - dense_3_acc_6: 0.9717 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9919 - dense_3_acc_9: 0.9771
 5120/10000 [==============>...............] - ETA: 0s - loss: 0.3729 - dense_3_loss: 0.1022 - dense_3_acc: 0.9994 - dense_3_acc_1: 0.9998 - dense_3_acc_2: 0.9967 - dense_3_acc_3: 0.9975 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9924 - dense_3_acc_6: 0.9709 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9930 - dense_3_acc_9: 0.9766
 6144/10000 [=================>............] - ETA: 0s - loss: 0.3841 - dense_3_loss: 0.1050 - dense_3_acc: 0.9992 - dense_3_acc_1: 0.9997 - dense_3_acc_2: 0.9964 - dense_3_acc_3: 0.9967 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9932 - dense_3_acc_6: 0.9701 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9932 - dense_3_acc_9: 0.9753
 7168/10000 [====================>.........] - ETA: 0s - loss: 0.3887 - dense_3_loss: 0.1053 - dense_3_acc: 0.9993 - dense_3_acc_1: 0.9997 - dense_3_acc_2: 0.9964 - dense_3_acc_3: 0.9965 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9926 - dense_3_acc_6: 0.9701 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9925 - dense_3_acc_9: 0.9750
 8192/10000 [=======================>......] - ETA: 0s - loss: 0.3896 - dense_3_loss: 0.1053 - dense_3_acc: 0.9994 - dense_3_acc_1: 0.9998 - dense_3_acc_2: 0.9967 - dense_3_acc_3: 0.9967 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9927 - dense_3_acc_6: 0.9694 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9922 - dense_3_acc_9: 0.9752
 9216/10000 [==========================>...] - ETA: 0s - loss: 0.3870 - dense_3_loss: 0.1053 - dense_3_acc: 0.9995 - dense_3_acc_1: 0.9998 - dense_3_acc_2: 0.9966 - dense_3_acc_3: 0.9969 - dense_3_acc_4: 1.0000 - dense_3_acc_5: 0.9925 - dense_3_acc_6: 0.9696 - dense_3_acc_7: 1.0000 - dense_3_acc_8: 0.9925 - dense_3_acc_9: 0.9752
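
A rough breakdown of the 52,960 parameters, computed from the layer sizes above:

Bidirectional LSTM, 32 units per direction over 37-dim one-hot inputs: 2 × 4 × (32 × (37 + 32) + 32) = 17,920
Dense(10, tanh) on the 128-dim concatenation of a and s: 128 × 10 + 10 = 1,290
Dense(1, relu): 10 × 1 + 1 = 11
Post-attention LSTM, 64 units over the 64-dim context: 4 × (64 × (64 + 64) + 64) = 33,024
Output Dense(11, softmax): 64 × 11 + 11 = 715

Total: 17,920 + 1,290 + 11 + 33,024 + 715 = 52,960, which matches the summary. The RepeatVector, Concatenate, Activation, and Dot layers contribute no parameters.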

Load the saved model and test it on some inputs of my own:

import keras
import nmt_utils
import os
import numpy as np
import matplotlib.pyplot as plt



os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # suppress TensorFlow info/warning messages
dataset, human, machine, inv_machine = nmt_utils.load_dataset(10000)
tx = 30  # max length of the human-readable input
ty = 10  # length of the machine-readable output
timestep = tx
n_a = 32
X, Y, Xoh, Yoh = nmt_utils.preprocess_data(dataset, human, machine, tx, ty)

model = keras.models.load_model("week3_model.h5", custom_objects={"softmax": nmt_utils.softmax})  # the attention softmax is the custom one from nmt_utils
model.load_weights("week3_weight.h5")
#model.summary()


EXAMPLES = ['3 May 1979', '5 April 09', '21th of Aug 2016', 'Tue 10 Jul 2007', 'Saturday May 9 2018', 'Mrch 3 2001',
            'March 3rd 2001', '1 March 2001','1.25.1997','April.26.1990']

for example in EXAMPLES:
    source = nmt_utils.string_to_int(example, tx,  human)
    source = np.array(list(map(lambda x: nmt_utils.to_categorical(x, num_classes=len(human)), source)))
    source = np.expand_dims(source, axis=0)
    s0 = np.zeros((1, n_a * 2))  # predicting one example at a time, so the initial states need batch size 1
    c0 = np.zeros((1, n_a * 2))
    prediction = model.predict([source, s0, c0])
    prediction = np.argmax(prediction, axis=-1)
    output = [inv_machine[int(i)] for i in prediction]
    print("source:", example,"    output:", ''.join(output))

attention_map=nmt_utils.plot_attention_map(model,human,inv_machine,'3 May 1979',64)

Output: most of the formats are handled well, although a few examples are still wrong (e.g. '21th of Aug 2016', 'Tue 10 Jul 2007', and '1.25.1997' come out with an incorrect day or year).

source: 3 May 1979     output: 1979-05-03
source: 5 April 09     output: 2009-05-05
source: 21th of Aug 2016     output: 2016-08-00
source: Tue 10 Jul 2007     output: 2007-07-00
source: Saturday May 9 2018     output: 2018-05-09
source: Mrch 3 2001     output: 2001-03-03
source: March 3rd 2001     output: 2001-03-03
source: 1 March 2001     output: 2001-03-01
source: 1.25.1997     output: 1999-10-25
source: April.26.1990     output: 1990-04-26

 
