手把手实现AI诗歌生成(AI写诗)

本文通过PyTorch实现了一个基于LSTM的AI诗歌生成模型,详细介绍了数据预处理、模型构建、训练过程以及生成效果。模型以唐诗为训练集,使用字符级别的语言模型,通过1层词嵌入层和2层LSTM预测下一个字符,利用交叉熵损失函数进行优化。生成的诗歌初步展现了古诗韵味。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本模型采用的是字符级别的诗歌生成(pytorch)

环境:

python3.X

pytorch GPU或CPU版本都行,

另外天有点冷,建议用GPU训练,电脑绝对比暖手宝好用

目录

项目文件结构:

数据已经打包:

1、数据集处理

2、构建模型与训练模型

基于概率语言模型的模型

网络结构

网络输入输出

损失函数

3、生成诗歌

4、配置文件

5、生成效果

6、结语

参考:


项目文件结构:

data:存放预处理好的数据

model:存放训练好的模型

config.py:配置文件

dataHandler.py:数据预处理及生成词典

model.py:模型文件

train.py:训练模型

generation.py:生成诗歌

poetry.txt:全唐诗,四万多首,中华民族艺术瑰宝。

数据已经打包:

链接:https://pan.baidu.com/s/1UAJFf3kKERm_XR0qRNneig 
提取码:e3y5

1、数据集处理

以四万首唐诗的文本作为训练集

它长这样:

文本中每行是一首诗,且使用冒号分割,前面是标题,后面是正文,且诗的长度不一。

对数据的处理流程大致:

  1. 读取文本,按行切分,构成古诗列表。
  2. 将全角、半角的冒号统一替换成半角的。
  3. 按冒号切分诗的标题和内容,只保留诗的内容。
  4. 最后根据诗的内容构建词典,并将处理好的数据保

 处理后的诗歌大概长这样

 代码如下:

# dataHandler.py
import numpy as np
from config import Config


def pad_sequences(sequences,
                  maxlen=None,
                  dtype='int32',
                  padding='pre',
                  truncating='pre',
                  value=0.):
    """
    # 填充
    code from keras
    Pads each sequence to the same length (length of the longest sequence).
    If maxlen is provided, any sequence longer
    than maxlen is truncated to maxlen.
    Truncation happens off either the beginning (default) or
    the end of the sequence.
    Supports post-padding and pre-padding (default).
    Arguments:
        sequences: list of lists where each element is a sequence
        maxlen: int, maximum length
        dtype: type to cast the resulting sequence.
        padding: 'pre' or 'post', pad either before or after each sequence.
        truncating: 'pre' or 'post', remove values from sequences larger than
            maxlen either in the beginning or in the end of the sequence
        value: float, value to pad the sequences to the desired value.
    Returns:
        x: numpy array with dimensions (number_of_sequences, maxlen)
    Raises:
        ValueError: in case of invalid values for `truncating` or `padding`,
            or in case of invalid shape for a `sequences` entry.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    lengths = []
    for x in sequences:
        if not hasattr(x, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(x))
        lengths.append(len(x))

    num_samples = len(sequences)
    if maxlen is None:
        maxlen = np.max(lengths)

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.
    sample_shape = tuple()
    for s in sequences:
        if len(s) > 0:  # pylint: disable=g-explicit-length-test
            sample_shape = np.asarray(s).shape[1:]
            break

    x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype)
    for idx, s in enumerate(sequences):
        if not len(s):  # pylint: disable=g-explicit-length-test
            continue  # empty list/array was found
        if truncating == 'pre':
            trunc = s[-maxlen:]  # pylint: disable=invalid-unary-operand-type
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError('Truncating type "%s" not understood' % truncating)

        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError(
                'Shape of sample %s of sequence at position %s is different from '
                'expected shape %s'
                % (trunc.shape[1:], idx, sample_shape))

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x

def load_poetry(poetry_file, max_gen_len):
    # 加载数据集
    with open(poetry_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
        # 将冒号统一成相同格式
        lines = [line.replace(':', ':') for line in lines]
    # 数据集列表
    poetry = []
    # 逐行处理读取到的数据
    for line in lines:
        # 有且只能有一个冒号用来分割标题
        if line.count(':') != 1:
            continue
        # 
智能网联汽车的安全员高级考试涉及多个方面的专业知识,包括但不限于自动驾驶技术原理、车辆传感器融合、网络安全防护以及法律法规等内容。以下是针对该主题的一些核心知识解析: ### 关于智能网联车安全员高级考试的核心内容 #### 1. 自动驾驶分级标准 国际自动机工程师学会(SAE International)定义了六个级别的自动驾驶等级,从L0到L5[^1]。其中,L3及以上级别需要安全员具备更高的应急处理能力。 #### 2. 车辆感知系统的组成与功能 智能网联车通常配备多种传感器,如激光雷达、毫米波雷达、摄像头和超声波传感器等。这些设备协同工作以实现环境感知、障碍物检测等功能[^2]。 #### 3. 数据通信与网络安全 智能网联车依赖V2X(Vehicle-to-Everything)技术进行数据交换,在此过程中需防范潜在的网络攻击风险,例如中间人攻击或恶意软件入侵[^3]。 #### 4. 法律法规要求 不同国家和地区对于无人驾驶测试及运营有着严格的规定,考生应熟悉当地交通法典中有关自动化驾驶部分的具体条款[^4]。 ```python # 示例代码:模拟简单决策逻辑 def decide_action(sensor_data): if sensor_data['obstacle'] and not sensor_data['emergency']: return 'slow_down' elif sensor_data['pedestrian_crossing']: return 'stop_and_yield' else: return 'continue_driving' example_input = {'obstacle': True, 'emergency': False, 'pedestrian_crossing': False} action = decide_action(example_input) print(f"Action to take: {action}") ``` 需要注意的是,“同学”作为特定平台上的学习资源名称,并不提供官方认证的标准答案集;建议通过正规渠道获取教材并参加培训课程来准备此类资格认证考试
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值