tensorflow之用LSTM实现hihell->ihello

LSTM实现hihell->ihello

//hihell->ihello
import tensorflow as tf
import numpy as np

#data
char_x = 'hihell'
char_y = 'ihello'
char = char_x+char_y
#Eliminate duplicate values
char = list(set(char))
#word to id
char_set = {w:i for i,w in enumerate(char)}
x_data = [[char_set[c] for c in char_x]]
y_data = [[char_set[c] for c in char_y]]

#Defining a hyperparameter
inputs = outputs = len(char)
sequence_length = 6
batch_size = 1
units = 8

#LSTM model
def LSTM(x_ont_hot):
    cell = tf.nn.rnn_cell.LSTMCell(num_units=units)
    output,state = tf.nn.dynamic_rnn(cell,x_ont_hot,dtype=tf.float32)
    output = tf.reshape(output,[-1,units])

    logits = tf.layers.dense(output,outputs,activation=None)
    logits = tf.reshape(logits,[batch_size,sequence_length,outputs])
    return logits

#training model
def trainModel():
    #Define placeholders
    x = tf.placeholder(tf.int32,[None,sequence_length])
    y = tf.placeholder(tf.int32,[None,sequence_length])

    #One-Hot Encoding
    x_one_hot = tf.one_hot(x,outputs)
    logits = LSTM(x_one_hot)
    #The initial weights
    weights = tf.ones([batch_size,sequence_length])
    #to count the cost
    sequence_loss = tf.contrib.seq2seq.sequence_loss(logits=logits,targets=y,weights=weights)
    loss = tf.reduce_mean(sequence_loss)
    #Defining the optimizer
    train = tf.train.AdamOptimizer(0.1).minimize(loss)
    #predict the outcome
    predict = tf.argmax(logits,axis=2)
    #Start Conversation Training
    with tf.Session() as sess:
        #Global variable initialization
        sess.run(tf.global_variables_initializer())
        for i in range(50):
            # iterative loop
            loss_val,_ = sess.run([loss,train],feed_dict={x:x_data,y:y_data})
            #Predict the output
            pred = sess.run(predict,feed_dict={x:x_data})
            pre = [char[c] for c in np.squeeze(pred)]
            print('第{:d}次的损失为:{:4f}预测结果为:{}'.format(i+1,loss_val,''.join(pre)))
#Main program entry
if __name__ == '__main__':
    trainModel()

学习讨论群请加QQ群:521364338,扫码进群领取人工智能学习资料,转载标明出处,侵权必究!

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

XiaoChao_AI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值