import collections
import numpy as np
start_token = 'B'  # marker prepended to every kept poem ("begin")
end_token = 'E'  # marker appended to every kept poem ("end")


def process_poems(file_name):
    """Load a poem corpus and encode every poem as a list of integer ids.

    Each line of ``file_name`` is expected to have the form ``title:content``.
    A line is skipped when:

    * it does not split into exactly two fields on ``':'`` (no separator,
      or extra colons);
    * its content (after removing spaces) contains any of ``_``, ``(``,
      ``(``, ``《``, ``[``, ``start_token`` or ``end_token``;
    * its content is shorter than 5 or longer than 79 characters.

    Each kept poem is wrapped as ``start_token + content + end_token``.

    Args:
        file_name: Path of a UTF-8 text file with one poem per line.

    Returns:
        ``(poems_vector, word_int_map, words)`` where

        * ``words`` is a tuple of every character seen in the kept poems,
          sorted by descending frequency (ties keep first-seen order), with
          a single ``' '`` appended last as the padding symbol;
        * ``word_int_map`` maps each entry of ``words`` to its index;
        * ``poems_vector`` holds one list of ids per kept poem. Poems differ
          in length, so this is a ragged list, not a matrix.
        For an empty/fully-filtered corpus this is ``([], {' ': 0}, (' ',))``.
    """
    # Substrings that mark annotated or otherwise unusable poems. Both the
    # ASCII '(' and the full-width '(' are rejected; the previous check
    # tested the ASCII '(' twice, so full-width parentheses slipped through.
    banned = ('_', '(', '(', '《', '[', start_token, end_token)

    poems = []
    with open(file_name, "r", encoding='utf-8') as f:
        for line in f:
            fields = line.strip().split(':')
            if len(fields) != 2:
                # Malformed line (no title separator or extra colons): skip it,
                # exactly as the old `title, content = ...split(':')` unpacking did.
                continue
            content = fields[1].replace(' ', '')
            if any(token in content for token in banned):
                continue
            # Poems with fewer than 5 or more than 79 characters are treated as outliers.
            if len(content) < 5 or len(content) > 79:
                continue
            poems.append(start_token + content + end_token)

    counter = collections.Counter(char for poem in poems for char in poem)
    # most_common() with no argument is sorted(items, key=count, reverse=True),
    # i.e. the same stable, frequency-descending order as before. It also works
    # for an empty counter, where the old `zip(*count_pairs)` unpacking crashed.
    words = tuple(char for char, _ in counter.most_common())
    # Spaces were stripped from every poem, so ' ' gets a unique id (the last)
    # that callers can use for padding.
    words = words + (' ',)
    word_int_map = {char: idx for idx, char in enumerate(words)}
    # Every character of every kept poem is in word_int_map by construction.
    poems_vector = [[word_int_map[char] for char in poem] for poem in poems]
    return poems_vector, word_int_map, words
def generate_batch(batch_size, poems_vec, word_to_int):
n_chunk = len(poems_vec) // batch_size
x_batches = []
y_batches = []
for i in range(n_chunk):
start_index = i * batch_size
end_index = start_index + batch_size
batches = poems_vec[start_index:end_index]
length = max(map(len, batches))
x_data = np.full((batch_size, length), word_to_int[' '], np.int32)
最新推荐文章于 2024-09-19 19:05:43 发布