【导读】本文翻译自TensorFlow官网新出的教程《Text generation using a RNN with eager execution》,该教程介绍如何使用TensorFlow Eager(动态图)和RNN来学习生成莎士比亚的作品。模型可以根据已有的字符序列来预测序列的下一个字符,以达到文本生成的效果。
简介
教程包含了用Tensorflow Eager(动态图)和tf.keras实现的可执行代码,下面是代码运行的示例结果:
QUEENE:
I had thought thou hadst a Roman; for the oracle,
Thus by All bids the man against the word,
Which are so weak of care, by old care done;
Your children were in your holy love,
And the precipitation through the bleeding throne.
BISHOP OF ELY:
Marry, and will, my lord, to weep in such a one were prettiest;
Yet now I was adopted heir
Of the world's lamentable day,
To watch the next way with his father with his face?
ESCALUS:
The cause why then we are all resolved more sons.
VOLUMNIA:
O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead,
And love and pale as any will to that word.
QUEEN ELIZABETH:
But how long have I heard the soul for this world,
And show his hands of life be proved to stand.
PETRUCHIO:
I say he look'd on, if I must be content
To stay him from the fatal of our country's bliss.
His lordship pluck'd from this sentence then for prey,
And then let us twain, being the moon,
were she such a case as fills m
虽然生成的句子中有一部分看起来比较符合语法,大多数生成的句子还是没有什么意义的。这个模型并没有考虑词的意义,而是考虑:
该模型是基于字符的,模型并不知道如何用字符拼写单词,甚至不知道词是文本的组成单元。
文本的结构很像戏剧,和训练集类似,文本块往往以说话者的大写名字开始。
模型中序列长度为100的序列上进行训练,但它有能力去生成更长的序列。
简介
导入相关库:
importtensorflow astf
tf.enable_eager_execution()
importnumpy asnp
importos
importtime
下载数据集:
path_to_file = tf.keras.utils.get_file('shakespeare.txt','https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
读取数据:
text = open(path_to_file).read()print('Length of text: {} characters'.format(len(text)))
查看数据:
print(text[:1000])
First Citizen:
Before we proceed anyfurther,hear me speak.
All:
Speak,speak.
First Citizen:
You are allresolved rather to die than to famish?
All:
Resolved. resolved.
First Citizen:
First,you know Caius Marcius ischief enemy to the people.
All:
We know't, we know't.
First Citizen:
Let us kill him, andwe'll have corn at our own price.Is't a verdict?All:
No more talking on't; let it be done: away, away!Second Citizen:
One word,good citizens.
First Citizen:
We are accounted poor citizens,the patricians good.
What authority surfeits on would relieve us: ifthey
would yieldus but the superfluity, whileit were
wholesome,we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us,the objectof our misery, is asan
inventory to particularise their abundance; our
sufferance isa gain to them Let us revenge this withour pikes,ere we become rakes: forthe gods know I
speak this inhunger forbread, not inthirst forrevenge.
构建字符集合:
# The unique characters in the filevocab = sorted(set(text))
print ('{} unique characters'.format(len(vocab)))
65unique characters
文本处理
构建字符索引:
char2idx = {u:i fori,u inenumerate(vocab)}
idx2char = np.array(vocab)
text_as_int = np.array([char2idx[c] forc intext])
forchar,_ inzip(char2idx,range(20)):
print('{:6s} ---> {:4d}'.format(repr(char),char2idx[char]))
'j'---> 48'f'---> 44'R'---> 30':'---> 10'W'---> 35';'---> 11'o'---> 53'b'---> 40'K'---> 23'L'---> 24'O'---> 27'h'---> 46'm'---> 51'u'---> 59'H'---> 20'z'---> 64'!'---> 2'S'---> 31'N'---> 26'Z'---> 38
预测任务:
给定一个字符或一个字符序列,希望预测下一个最可能出现的字符。在训练时,我们的模型输入seq_length个字符,输出seq_length个字符。例如seq_length为4,我们的文本你是Hello,那么输入是Hell,输出是ello。
首先获得多个长度为seq_length的文本:
seq_length = 100chunks = tf.data.Dataset.from_tensor_slices(text_as_int).batch(seq_length+1,drop_remainder=True)
foritem inchunks.take(5):
print(repr(''.join(idx2char[item.numpy()])))
'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou ''are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k'
"now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki"
"ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d"
'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'
从文本中提取输入和目标:
defsplit_input_target(chunk):
input_text = chunk[:-1]
target_text = chunk[1:]
returninput_text,target_text
dataset = chunks.map(split_input_target)
forinput_example,target_example indataset.take(1):
print('Input data: ',repr(''.join(idx2char[input_example.numpy()])))
print('Target data:',repr(''.join(idx2char[target_example.numpy()])))
Input data: 'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou'Target data: 'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
利用tf.data来划分batch,并进行shuffle:
# batch大小BATCH_SIZE = 64# shuffle缓存大小BUFFER_SIZE = 10000dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE,drop_remainder=True)
模型
模型类,使用了tf.keras的Embedding和GRU层:
classModel(tf.keras.Model):
def__init__(self,vocab_size,embedding_dim,units):
super(Model,self).__init__()
self.units = units
self.embedding = tf.keras.layers.Embedding(vocab_size,embedding_dim)
iftf.test.is_gpu_available():
self.gru = tf.keras.layers.CuDNNGRU(self.units,return_sequences=True,recurrent_initializer='glorot_uniform',stateful=True)
else:
self.gru = tf.keras.layers.GRU(self.units,return_sequences=True,recurrent_activation='sigmoid',recurrent_initializer='glorot_uniform',stateful=True)
self.fc = tf.keras.layers.Dense(vocab_size)
defcall(self,x):
embedding = self.embedding(x)
# output at every time step
# output shape == (batch_size, seq_length, hidden_size)output = self.gru(embedding)
# The dense layer will output predictions for every time_steps(seq_length)
# output shape after the dense layer == (seq_length * batch_size, vocab_size)prediction = self.fc(output)
# states will be used to pass at every step to the model while trainingreturnprediction
实例化模型、优化器和损失函数:
# 字典大小vocab_size = len(vocab)
# 字符向量维度embedding_dim = 256# RNN单元维度units = 1024model = Model(vocab_size,embedding_dim,units)
optimizer = tf.train.AdamOptimizer()
defloss_function(real,preds):
returntf.losses.sparse_softmax_cross_entropy(labels=real,logits=preds)
训练模型:
model.build(tf.TensorShape([BATCH_SIZE,seq_length]))
model.summary()
_________________________________________________________________
Layer (type) Output Shape Param #=================================================================
embedding (Embedding) multiple 16640_________________________________________________________________
gru (GRU) multiple 3935232_________________________________________________________________
dense (Dense) multiple 66625=================================================================
Total params: 4,018,497Trainable params: 4,018,497Non-trainable params: 0_________________________________________________________________
# 保存checkpoints的目录checkpoint_dir = './training_checkpoints'# Checkpoint文件名checkpoint_prefix = os.path.join(checkpoint_dir,"ckpt")
简单训练几次:
EPOCHS = 5
# 循环训练forepoch inrange(EPOCHS):
start = time.time()
# 在每轮epoch开始时,初始化状态hidden = model.reset_states()
for(batch,(inp,target)) inenumerate(dataset):
withtf.GradientTape() astape:predictions = model(inp)
loss = loss_function(target,predictions)
grads = tape.gradient(loss,model.variables)
optimizer.apply_gradients(zip(grads,model.variables))
ifbatch % 100== 0:
print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,batch,loss))
# 每5个epoch保存一下模型if(epoch + 1) % 5== 0:
model.save_weights(checkpoint_prefix)
print('Epoch {} Loss {:.4f}'.format(epoch + 1,loss))
print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))
Epoch 3Batch 0Loss 1.7645Epoch 3Batch 100Loss 1.6853Epoch 3Loss 1.6164Time taken for1epoch 610.0756878852844sec
Epoch 4Batch 0Loss 1.6491Epoch 4Batch 100Loss 1.5350Epoch 4Loss 1.5071Time taken for1epoch 609.8330454826355sec
Epoch 5Batch 0Loss 1.4715Epoch 5Batch 100Loss 1.4685Epoch 5Loss 1.4042Time taken for1epoch 608.6753587722778sec
保存模型:
model.save_weights(checkpoint_prefix)
模型载入:
如果需要载入保存的模型,可以使用下面的代码:
model = Model(vocab_size,embedding_dim,units)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))
使用训练的模型来进行文本生成:
# 生成文本
# 生成字符的数量num_generate = 1000# 起始字符start_string = 'Q'input_eval = [char2idx[s] fors instart_string]
input_eval = tf.expand_dims(input_eval,0)
# 用于存储结果的空字符串text_generated = []
# 较低的temperature值会产出一些预料之中的文本.
# 较高的temperature值会产出一些出乎预料的文本.temperature = 1.0
# 这里batch大小为1model.reset_states()
fori inrange(num_generate):
predictions = model(input_eval)
# 移除batch维predictions = tf.squeeze(predictions,0)
# 用多项式分布来预测生成的词predictions = predictions / temperature
predicted_id = tf.multinomial(predictions,num_samples=1)[-1,0].numpy()
# 将生成的词和上一次的隐藏状态作为模型的下一个输入input_eval = tf.expand_dims([predicted_id],0)
text_generated.append(idx2char[predicted_id])
print(start_string + ''.join(text_generated))
运行结果:
QULERBY:
If a body.
But I would me your lood.
Steak ungrace and asthis only inthe ploaduse,his they,much you amed on't.RSCALIO:
Hearn' thousand as your well, and obepional.ANTONIO:
Can wathach this wam a discure that braichal heep itspose,Teparmate confoim it: never knor sheep,so litter
Plarence? He,But thou sunds a parmon servection:
Occh Rom o'ld him sir;madish yim,I'll surm let as hand upherityShepherd:
Why do I sering their stumble; the thank emo'st yiedBaunted unpluction; the main,sir,What's a meanulainstEven worship tebomn slatued of his name,Manisholed shorks you go?
BUCKINGHAM:
We look thus then impare'd least itsiby drumes,That I,what!
Nurset,fell beshee that which I will
to the near-Volshing upon this aguin against fless
Is done untlein with isthe neck,Thands he shall fear'ds; let me love at officed:Where elseto her awticions, asyou hall,my lord.
KING RICHARD II:
I will been another one our accuser less
Tiold,methought to the presench of consiar
参考资料:
https://www.tensorflow.org/tutorials/sequences/text_generation
https://github.com/tensorflow/docs/blob/master/site/en/tutorials/sequences/text_generation.ipynb
-END-
专 · 知
人工智能领域26个主题知识资料全集获取与加入专知人工智能服务群: 欢迎微信扫一扫加入专知人工智能知识星球群,获取专业知识教程视频资料和与专家交流咨询!
请PC登录www.zhuanzhi.ai或者点击阅读原文,注册登录专知,获取更多AI知识资料!
请加专知小助手微信(扫一扫如下二维码添加),加入专知主题群(请备注主题类型:AI、NLP、CV、 KG等)交流~
AI 项目技术 & 商务合作:bd@zhuanzhi.ai, 或扫描上面二维码联系!
请关注专知公众号,获取人工智能的专业知识!
点击“阅读原文”,使用专知