手把手实现AI诗歌生成(AI写诗)

本模型采用的是字符级别的诗歌生成(pytorch)

环境:

python3.X

pytorch GPU或CPU版本都行,

另外天有点冷,建议用GPU训练,电脑绝对比暖手宝好用

目录

项目文件结构:

数据已经打包:

1、数据集处理

2、构建模型与训练模型

基于概率语言模型的模型

网络结构

网络输入输出

损失函数

3、生成诗歌

4、配置文件

5、生成效果

6、结语

参考:


项目文件结构:

data:存放预处理好的数据

model:存放训练好的模型

config.py:配置文件

dataHandler.py:数据预处理及生成词典

model.py:模型文件

train.py:训练模型

generation.py:生成诗歌

poetry.txt:全唐诗,四万多首,中华民族艺术瑰宝。

数据已经打包:

链接:https://pan.baidu.com/s/1UAJFf3kKERm_XR0qRNneig 
提取码:e3y5

1、数据集处理

以四万首唐诗的文本作为训练集

它长这样:

文本中每行是一首诗,且使用冒号分割,前面是标题,后面是正文,且诗的长度不一。

对数据的处理流程大致:

  1. 读取文本,按行切分,构成古诗列表。
  2. 将全角、半角的冒号统一替换成半角的。
  3. 按冒号切分诗的标题和内容,只保留诗的内容。
  4. 最后根据诗的内容构建词典,并将处理好的数据保

 处理后的诗歌大概长这样

 代码如下:

# dataHandler.py
import numpy as np
from config import Config


def pad_sequences(sequences,
                  maxlen=None,
                  dtype='int32',
                  padding='pre',
                  truncating='pre',
                  value=0.):
    """
    # 填充
    code from keras
    Pads each sequence to the same length (length of the longest sequence).
    If maxlen is provided, any sequence longer
    than maxlen is truncated to maxlen.
    Truncation happens off either the beginning (default) or
    the end of the sequence.
    Supports post-padding and pre-padding (default).
    Arguments:
        sequences: list of lists where each element is a sequence
        maxlen: int, maximum length
        dtype: type to cast the resulting sequence.
        padding: 'pre' or 'post', pad either before or after each sequence.
        truncating: 'pre' or 'post', remove values from sequences larger than
            maxlen either in the beginning or in the end of the sequence
        value: float, value to pad the sequences to the desired value.
    Returns:
        x: numpy array with dimensions (number_of_sequences, maxlen)
    Raises:
        ValueError: in case of invalid values for `truncating` or `padding`,
            or in case of invalid shape for a `sequences` entry.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    lengths = []
    for x in sequences:
        if not hasattr(x, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(x))
        lengths.append(len(x))

    num_samples = len(sequences)
    if maxlen is None:
        maxlen = np.max(lengths)

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.
    sample_shape = tuple()
    for s in sequences:
        if len(s) > 0:  # pylint: disable=g-explicit-length-test
            sample_shape = np.asarray(s).shape[1:]
            break

    x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype)
    for idx, s in enumerate(sequences):
        if not len(s):  # pylint: disable=g-explicit-length-test
            continue  # empty list/array was found
        if truncating == 'pre':
            trunc = s[-maxlen:]  # pylint: disable=invalid-unary-operand-type
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError('Truncating type "%s" not understood' % truncating)

        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError(
                'Shape of sample %s of sequence at position %s is different from '
                'expected shape %s'
                % (trunc.shape[1:], idx, sample_shape))

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x

def load_poetry(poetry_file, max_gen_len):
    # 加载数据集
    with open(poetry_file, 'r', encoding='utf-8') as f:
        lines = f.readl
  • 7
    点赞
  • 67
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值