PyTorch示例——使用Transformer写古诗
1. 前言
2. 版本信息
PyTorch: 2.1.2
Python: 3.10.13
3. 导包
import math
import random
import sys
from collections import Counter
from itertools import islice

import numpy as np
import torch
import tqdm
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
# Report the library/runtime versions this notebook was run with.
print(f"Pytorch 版本: {torch.__version__}")
print(f"Python 版本: {sys.version}")
Pytorch 版本: 2.1.2
Python 版本: 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
4. 数据与预处理
数据下载
先看一下原始数据
# Raw dataset: one poem per line, formatted as "title:content".
DATA_PATH = '/kaggle/input/poetry/poetry.txt'
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()

# Peek at the first few raw lines. Slicing (instead of indexing 0..4)
# avoids an IndexError if the file has fewer than five lines.
for line in lines[:5]:
    print(line)
# Keep the f-string on one line: before Python 3.12 a newline inside a
# single-quoted f-string is a SyntaxError.
print(f"origin_line_count = {len(lines)}")
首春:寒随穷律变,春逐鸟声开。初风飘带柳,晚雪间花梅。碧林青旧竹,绿沼翠新苔。芝田初雁去,绮树巧莺来。
初晴落景:晚霞聊自怡,初晴弥可喜。日晃百花色,风动千林翠。池鱼跃不同,园鸟声还异。寄言博通者,知予物外志。
初夏:一朝春夏改,隔夜鸟花迁。阴阳深浅叶,晓夕重轻烟。哢莺犹响殿,横丝正网天。珮高兰影接,绶细草纹连。碧鳞惊棹侧,玄燕舞檐前。何必汾阳处,始复有山泉。
度秋:夏律昨留灰,秋箭今移晷。峨嵋岫初出,洞庭波渐起。桂白发幽岩,菊黄开灞涘。运流方可叹,含毫属微理。
仪鸾殿早秋:寒惊蓟门叶,秋发小山枝。松阴背日转,竹影避风移。提壶菊花岸,高兴芙蓉池。欲知凉气早,巢空燕不窥。
origin_line_count = 43030
开始处理数据，过滤掉异常数据（格式不符、长度不合适或包含特殊符号的诗句）
# Length limits for a poem body. The "- 2" applied to MAX_LEN below
# presumably reserves two positions for the <bos>/<eos> markers that the
# tokenizer adds — TODO confirm against the model's sequence length.
MAX_LEN = 64
MIN_LEN = 5
# Poems containing any of these symbols are treated as noisy and dropped.
# Both ASCII and full-width variants are listed: the corpus is Chinese text,
# where brackets and punctuation are usually full-width.
DISALLOWED_WORDS = [
    '(', ')', '（', '）',
    '__',
    '《', '》',
    '【', '】',
    '[', ']',
    '?', '？',
    ';', '；',
]

poetry = []
# Re-read the file so this cell does not depend on the previous one.
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    lines = f.readlines()

for line in lines:
    # Keep only well-formed "title:content" lines.
    fields = line.split(":")
    if len(fields) != 2:
        continue
    # Strip the trailing newline *before* measuring the length; otherwise
    # the newline is counted as a character and both limits are off by one.
    content = fields[1].strip()
    if len(content) > MAX_LEN - 2 or len(content) < MIN_LEN:
        continue
    if any(word in content for word in DISALLOWED_WORDS):
        continue
    poetry.append(content)

# Preview the cleaned poems; slicing is safe even if fewer than five remain.
for verse in poetry[:5]:
    print(verse)
print(f"current_line_count = {len(poetry)}")
寒随穷律变,春逐鸟声开。初风飘带柳,晚雪间花梅。碧林青旧竹,绿沼翠新苔。芝田初雁去,绮树巧莺来。
晚霞聊自怡,初晴弥可喜。日晃百花色,风动千林翠。池鱼跃不同,园鸟声还异。寄言博通者,知予物外志。
夏律昨留灰,秋箭今移晷。峨嵋岫初出,洞庭波渐起。桂白发幽岩,菊黄开灞涘。运流方可叹,含毫属微理。
寒惊蓟门叶,秋发小山枝。松阴背日转,竹影避风移。提壶菊花岸,高兴芙蓉池。欲知凉气早,巢空燕不窥。
山亭秋色满,岩牖凉风度。疏兰尚染烟,残菊犹承露。古石衣新苔,新巢封古树。历览情无极,咫尺轮光暮。
current_line_count = 24375
过滤掉出现频率较低的字符，后面统一当作 UNKNOWN 处理
# Characters that occur fewer than this many times across the corpus are
# left out of the vocabulary and treated as UNKNOWN later on.
MIN_WORD_FREQUENCY = 8

# Count character frequencies over all kept poems. Updating a Counter with
# a string counts each character individually.
counter = Counter()
for line in poetry:
    counter.update(line)

# Vocabulary: every character frequent enough to keep (insertion order).
tokens = [token for token, count in counter.items() if count >= MIN_WORD_FREQUENCY]

# Sanity check: show the first five counted characters in insertion order.
for token, count in islice(counter.items(), 5):
    print(token, "->", count)
寒 -> 2612
随 -> 1036
穷 -> 482
律 -> 118
变 -> 286
定义词典编码器 Tokenizer
class Tokenizer :
"""
词典编码器
"""
UNKNOWN = "<unknown>"
PAD = "<pad>"
BOS = "<bos>"
EOS = "<eos>"
def __init__ ( self, tokens) :
tokens = [ Tokenizer. UNKNOWN, Tokenizer. PAD, Tokenizer. BOS, Tokenizer. EOS] + tokens
self. dict_size = len