参考官网demo:lstm_text_generation
一、问题描述
本文记录使用 使用lstm 进行文本生成,数据来源为尼采著作。
注意为了使生成的文本连贯,至少需要20个纪元。因此建议在GPU上运行这个脚本,因为循环网络的计算量非常大。
如果在新数据上尝试此脚本,请确保您的语料库至少有100k个字符。 1米是更好的。
二、实现
1. 引包
关于R语言如果引包,可以参见我哦另一篇博客:R语言 —— 包(package)的下载和使用
注意引入readr 的时候可能会报错:
这是因为MRO安装程序似乎假定XQuartz安装在/opt/X11中。从其网站(https://www.xquartz.org)下载并安装XQuartz后,此警告消息消失。
library(keras)
library(readr)
library(stringr)
library(purrr)
library(tokenizers)
2. 代码及解读
library(keras)
library(readr)
library(stringr)
library(purrr)
library(tokenizers)
# Parameters --------------------------------------------------------------
# 提取的字符序列的长度
maxlen <- 40
# Data Preparation --------------------------------------------------------
# 下载语料库并将其转化为小写
# Retrieve text
path <- get_file(
'nietzsche.txt',
origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt'
)
# Load, collapse, and tokenize text
text <- read_lines(path) %>%
str_to_lower() %>%
str_c(collapse = "\n") %>%
tokenize_characters(strip_non_alphanum = FALSE, simplify = TRUE)
print(sprintf("corpus length: %d", length(text)))
chars <- text %>%
unique() %>%
sort()
print(sprintf("total chars: %d", length(chars)))
'''
接下来,将提取长度为“maxlen”的部分重叠序列,对它们进行one-hot
编码并将它们打包成形状为“(sequence,maxlen,unique_characters)”
的3D Numpy数组`x`。
同时,准备一个包含相应目标的数组`y`:在每个提取序列之后的one-hot编码字符。
'''
# by=3 对每‘by’个字符序列采样一个新序列
# sentence 用于保存提取到的序列
# next_char 用于保存targets
# Cut the text in semi-redundant sequences of maxlen characters
dataset <- map(
seq(1, length(text) - maxlen - 1, by = 3),
~list(sentece = text[.x:(.x + maxlen - 1)], next_char = text[.x + maxlen])
)
dataset <- transpose(dataset)
# 接下来,将字符one-hot编码为二维数组
# Vectorization
x <- array(0, dim = c(length(dataset$sentece), maxlen, length(chars)))
y <- array(0, dim = c(length(dataset$sentece), length(chars)))
for(i in 1:length(dataset$sentece)){
x[i,,] <- sapply(chars, function(x){
as.integer(x == dataset$sentece[[i]])
})
y[i,] <- as.integer(chars == dataset$next_char[[i]])
}
# Model Definition --------------------------------------------------------
'''
构建网络
网络是一个单独的'LSTM`层,后跟一个'Dense'分类器和所有可能字符的softmax。
循环神经网络不是生成序列数据的唯一方法; 1D convnets也被证明非常成功。
'''
model <- keras_model_sequential()
model %>%
layer_lstm(128, input_shape = c(maxlen, length(chars))) %>%
layer_dense(length(chars)) %>%
layer_activation("softmax")
optimizer <- optimizer_rmsprop(lr = 0.01)
model %>% compile(
loss = "categorical_crossentropy",
optimizer = optimizer
)
# Training & Results ----------------------------------------------------
'''
训练语言模型并从中抽样
给定已训练的模型和原文本片段,重复生成新文本:
* 1)从模型中得出目前可用文本的下一个字符的概率分布
* 2)将分布重新调整到某个“temperature”
* 3)根据重新加权的分布随机抽样下一个字符
* 4)在可用文本的末尾添加新字符
这是用来重新加权模型中出现的原始概率分布的代码,并从中绘制一个字符索引(“抽样函数”):
'''
sample_mod <- function(preds, temperature = 1){
preds <- log(preds)/temperature
exp_preds <- exp(preds)
preds <- exp_preds/sum(exp(preds))
rmultinom(1, 1, preds) %>%
as.integer() %>%
which.max()
}
'''
最后,反复训练和生成文本的循环。 开始在每个epoch之后使用一系列不同
的温度生成文本。 可以看到生成的文本在模型开始收敛时如何演变,
以及温度对抽样策略的影响。
'''
on_epoch_end <- function(epoch, logs) {
cat(sprintf("epoch: %02d ---------------\n\n", epoch))
for(diversity in c(0.2, 0.5, 1, 1.2)){
cat(sprintf("diversity: %f ---------------\n\n", diversity))
start_index <- sample(1:(length(text) - maxlen), size = 1)
sentence <- text[start_index:(start_index + maxlen - 1)]
generated <- ""
for(i in 1:400){
x <- sapply(chars, function(x){
as.integer(x == sentence)
})
x <- array_reshape(x, c(1, dim(x)))
preds <- predict(model, x)
next_index <- sample_mod(preds, diversity)
next_char <- chars[next_index]
generated <- str_c(generated, next_char, collapse = "")
sentence <- c(sentence[-1], next_char)
}
cat(generated)
cat("\n\n")
}
}
print_callback <- callback_lambda(on_epoch_end = on_epoch_end)
model %>% fit(
x, y,
batch_size = 128,
epochs = 1,
callbacks = print_callback
)
'''
如上所示,低的temperature会产生极其重复且可预测的文本,但是在本地结构非常逼真的情况下:
特别是所有单词(一个单词是本地字符模式)都是真正的英语单词。随着温度的升高,生成的文本变得更有趣,令人惊讶,甚至创造性;它有时可能会发明一些听起来有些合理的新词(例如“eterned”或“troveration”)。在高温下,局部结构开始分解,大多数单词看起来像半随机字符串。毫无疑问,这里的0.5是这个特定设置中文本生成最有趣的温度。始终尝试多种采样策略!学习结构和随机性之间的巧妙平衡是让生成有趣的原因。
请注意,通过训练更大的模型,更长的时间,更多的数据,您可以获得生成的样本,这些样本看起来比我们的更连贯和更真实。但是,当然,除了随机机会之外,不要期望生成任何有意义的文本:我们所做的只是从统计模型中采样数据,其中字符来自哪些字符。语言是一种通信渠道,通信的内容与通信编码的消息的统计结构之间存在区别。为了证明这种区别,这里有一个思想实验:如果人类语言在压缩通信方面做得更好,就像我们的计算机对大多数数字通信做的那样?那么语言就没那么有意义,但它缺乏任何内在的统计结构,因此无法像我们一样学习语言模型。
拿走
*我们可以通过训练模型来生成离散序列数据,以预测给定前一个令牌的下一个令牌。
*在文本的情况下,这种模型被称为“语言模型”,可以基于单词或字符。
*采样下一个标记需要在遵守模型判断的可能性和引入随机性之间取得平衡。
*处理这个的一种方法是_softmax temperature_的概念。总是尝试不同的温度来找到“正确”的温度。
'''