版权声明:本文为博主原创文章,未经博主允许不得转载。
一、目录结构介绍
- checkpoints:是生成唐诗时自动创建的
- dataset:存放数据集和读取诗歌文件(poem.py:预处理古诗词)
- models:存放模型构建(model.py)这是歌词与唐诗的共用代码
- inference:存放模型训练,包括训练和生成
- 主函数,主要是命令行的参数的构建(main.py)
二、整体思路
- 输入首字母,补全整首诗
- 先统计古诗词中的词频,进行词到数字的映射。生成poems_vector(词向量),word_to_int(词数字映射关系),words(词表)
三、代码实现
【第一步】poems.py: 读取诗的数据集 / 预处理古诗词
主要有2个函数构成:
(1)process_poems:
· 读取诗歌数据集(诗歌:标题、内容)
· 排除一些不必要的数据
· 统计每个字出现的次数,获取常用字
· 将每个字映射成一个数字ID(word_int_map),从而获得诗歌矢量(poems_vector)
(2)generate_batch:每次取一个batch进行训练(这里取64),获得一个epoch内有多少个batch
· 在一个epoch内迭代,获取这个batch的所有poem中最长的poem的长度
· 填充其它短的诗,空的地方放空格对应获得index标号
import collections
import os
import sys
import numpy as np
start_token = 'G'
end_token = 'E'
def process_poems(file_name):
# 诗集
poems = []
with open(file_name, 'r', encoding='utf-8') as f:
for line in f.readlines():
try:
title, content = line.strip().split(':')
content = content.replace(' ', '')
if '_' in content or '(' in content or '(' in content or '《' in content or '[' in content or start_token in content or end_token in content:
continue
if len(content) < 5 or len(content) > 79:
continue
content = start_token + content + end_token
poems.append(content)
except ValueError as e:
pass
# 按诗的字数排序
poems = sorted(poems, key=lambda l: len(line))
# 统计每个字出现的次数
all_words = []
for poem in poems:
all_words += [word for word in poem]
# 这里根据包含了每个字对应的频率
counter = collections.Counter(all_words)
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
words, _ = zip(*count_pairs)
# 取前多少个常用字
words = words[:len(words)] + (' ',)
# 每个字映射为一个数字ID
word_int_map = dict(zip(words, range(len(words))))
poems_vector = [list(map(lambda word: word_int_map.get(word, len(words)), poem)) for poem in poems]
return poems_vector, word_int_map, words
def generate_batch(batch_size, poems_vec, word_to_int):
n_chunk = len(poems_vec) // batch_size
x_batches = []
y_batches = []
for i in range(n_chunk):
start_index = i * batch_size
end_index = start_index + batch_size
batches = poems_vec[start_index: end_index]
# 找到这个batch 的所有poem 中最长的poem的长度
length = max(map(len, batches))
# 填充一个这么大小的空batch,空的地方放空格对应的index标号
x_data = np.full((batch_size, length), word_to_int[' '], np.int32)
for row in range(batch_size):
# 每一行就是一首诗,在原本的长度上把诗还原上去
x_data[row, :len(batches[row])] = batches[row]
y_data