Building a Chatbot from the Official PyTorch Tutorial (Annotated Code)

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math

USE_CUDA=torch.cuda.is_available()
device =torch.device('cuda' if USE_CUDA  else 'cpu')

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

"""
加载和预处理数据
下一步就是格式化处理我们的数据文件并加载到我们可以使用的结构中

Cornell Movie-Dialogs Corpus 是一个丰富的电影角色对话数据集:

10,292 对电影角色的220,579 次对话
617部电影中的9,035电影角色
总共304,713中语调
这个数据集庞大而多样,在语言形式、时间段、情感上等都有很大的变化。
我们希望这种多样性使我们的模型能够适应多种形式的输入和查询。

首先,我们通过数据文件的某些行来查看原始数据的格式
"""

corpus_name = "cornell movie-dialogs corpus"
corpus = os.path.join("data", corpus_name)

# Print the first n lines of a file (as raw bytes) to inspect its format
def printlines(file, n=10):
    with open(file, 'rb') as datafile:
        lines = datafile.readlines()

    for line in lines[:n]:
        print(line)
        
"""
printlines(os.path.join(corpus,"movie_lines.txt"))
print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
printlines(os.path.join(corpus,"movie_characters_metadata.txt"))
print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
printlines(os.path.join(corpus,"movie_titles_metadata.txt"))
print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
printlines(os.path.join(corpus,"movie_conversations.txt"))
print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
printlines(os.path.join(corpus,"chameleons.pdf"))
print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')
printlines(os.path.join(corpus,"chameleons.pdf"))
print('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>')

"""
        



"""
Create the Formatted Data File
For convenience, we'll create a nicely formatted data file in which each line contains a tab-separated query sentence / response sentence pair.

The following functions facilitate parsing the raw movie_lines.txt data file.

loadlines splits each line of the file into a dictionary of fields (lineID, characterID, movieID, character, text)
loadConversations groups the fields of lines from loadlines into conversations based on movie_conversations.txt
extractSentencePairs extracts pairs of sentences from the conversations
"""

"""
格式化并且加载数据
将文件的每一行拆分为字段字典
line = {
     'L183198': {
         'lineID': 'L183198', 
         'characterID': 'u5022', 
         'movieID': 'm333', 
         'character': 'FRANKIE', 
         'text': "Well we'd sure like to help you.\n"
     }, {...}
 }
"""
def loadlines(fileName, fields):
    lines = {}
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            # Fields in the raw file are separated by " +++$+++ "
            values = line.split(" +++$+++ ")
            # Extract the fields into a dictionary
            lineobj = {}
            for i, field in enumerate(fields):
                lineobj[field] = values[i]

            lines[lineobj["lineID"]] = lineobj

    return lines

"""
# Groups fields of lines from loadlines into conversations based on *movie_conversations.txt*, e.g.:
# [{
#     'character1ID': 'u0',
#     'character2ID': 'u2',
#     'movieID': 'm0',
#     'utteranceIDs': "['L194', 'L195', 'L196', 'L197']\n",
#     'lines': [{
#         'lineID': 'L194',
#         'characterID': 'u0',
#         'movieID': 'm0',
#         'character': 'BIANCA',
#         'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'
#     }, {
#         'lineID': 'L195',
#         'characterID': 'u2',
#         'movieID': 'm0',
#         'character': 'CAMERON',
#         'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"
#     }, {
#         'lineID': 'L196',
#         'characterID': 'u0',
#         'movieID': 'm0',
#         'character': 'BIANCA',
#         'text': 'Not the hacking and gagging and spitting part.  Please.\n'
#     }, {
#         'lineID': 'L197',
#         'characterID': 'u2',
#         'movieID': 'm0',
#         'character': 'CAMERON',
#         'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
#     }]
# }, {...}]
"""

def loadConversations(fileName, lines, fields):
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            # Fields in the raw file are separated by " +++$+++ "
            values = line.split(" +++$+++ ")
            # Extract the fields into a dictionary
            convObj = {}
            for i, field in enumerate(fields):
                convObj[field] = values[i]

            # 'utteranceIDs' is a string such as "['L194', 'L195', ...]"; eval it into a list
            lineIds = eval(convObj['utteranceIDs'])
            # Reassemble the conversation from the lines loaded by loadlines
            convObj["lines"] = []
            for lineId in lineIds:
                convObj["lines"].append(lines[lineId])
            conversations.append(convObj)

    return conversations

# Extract pairs of sentences (query, response) from conversations

def extractSentencePairs(conversations):
    qa_pairs = []
    for con in conversations:
        # Iterate over all the lines of the conversation
        for i in range(len(con['lines'])-1):
            inputline = con['lines'][i]['text'].strip()
            outputline = con['lines'][i+1]['text'].strip()

            # Filter out pairs where either sentence is empty
            if inputline and outputline:
                qa_pairs.append([inputline, outputline])

    return qa_pairs


"""
Now we'll call these functions to create the file, which we'll name formatted_movie_lines.txt. Sample output:
Processing corpus...

Loading conversations...

Writing newly formatted file...

Sample lines from file:
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\n"
b'Why?\tUnsolved mystery.  She used to be really popular when she started high school, then it was just like she got sick of it or something.\n'
b"Unsolved mystery.  She used to be really popular when she started high school, then it was just like she got sick of it or something.\tThat's a shame.\n"
b'Gosh, if only we could find Kat a boyfriend...\tLet me see what I can do.\n'
"""
datafile = os.path.join(corpus, "formatted_movie_lines.txt")
delimiter = '\t'
# Unescape the delimiter
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

lines={}
conversations=[]
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# Load lines and process conversations
print("\nProcessing corpus...")
lines=loadlines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations=loadConversations(os.path.join(corpus, "movie_conversations.txt"),
                                  lines, MOVIE_CONVERSATIONS_FIELDS)
# Write new csv file
print("\nWriting newly formatted file...")
with open(datafile,'w',encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)
        
        
# Print a sample of lines
print("\nSample lines from file:")
printlines(datafile)


"""
加载和清洗数据
我们下一个任务是创建词汇表并将查询/响应句子对(对话)加载到内存。

注意我们正在处理词序,这些词序没有映射到离散数值空间。因此,
我们必须通过数据集中的单词来创建一个索引。

为此我们创建了一个Voc类,它会存储从单词到索引的映射、索引到单词的反向映射、每个单词的计数和总单词量。
这个类提供向词汇表中添加单词的方法(addWord)、添加所有单词到句子中的方法 (addSentence) 和清洗不常见的单词方法(trim)。
更多的数据清洗在后面进行。
"""
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    def __init__(self,name):
        self.name=name
        self.trimmed=False
        self.word2index={}
        self.index2word={PAD_token:"PAD",SOS_token:"SOS",EOS_token:"EOS"}
        self.word2count={}
        self.num_words=3 #count PAD SOS EOS 
        
        
    def addSentence(self,sentence):
        for word in sentence.split(' '):
            self.addWord(word)
            
    def addWord(self,word):
        if word not in self.word2index:
            self.word2index[word]=self.num_words
            self.word2count[word]=1
            self.index2word[self.num_words]=word
            self.num_words+=1
        else:
            self.word2count[word]+=1
            
    # Remove words whose count falls below a certain threshold
    def trim(self, min_count):
        if self.trimmed:
            return
        self.trimmed=True
        
        keep_words=[]
        for k,v in self.word2count.items():
            if v>=min_count:
                keep_words.append(k)
                
        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), len(self.word2index), len(keep_words) / len(self.word2index)
        ))
        
        self.word2index={}
        self.word2count={}
        self.index2word={PAD_token:"PAD",SOS_token:"SOS",EOS_token:"EOS"}
        self.num_words=3
        
        for word in keep_words:
            self.addWord(word)
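
"""
A quick, hypothetical usage sketch of the Voc class (the sentences here are made up for
illustration; the real vocabulary is built from the corpus further below):

voc_demo = Voc("demo")
voc_demo.addSentence("hello there general kenobi")
voc_demo.addSentence("hello hello world")
print(voc_demo.num_words)            # 3 special tokens + 5 unique words = 8
print(voc_demo.word2index["hello"])  # 3 (the first word added after PAD/SOS/EOS)
voc_demo.trim(min_count=2)           # keeps only "hello", which was seen 3 times
"""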
            

"""
现在我们可以组装词汇表和查询/响应语句对。在使用数据之前,我们必须做一些预处理。

首先,我们必须使用unicodeToAscii将unicode字符串转换为ASCII。
然后,我们应该将所有字母转换为小写字母并清洗掉除基本标点之外的所有非字母字符 (normalizeString)。
最后,为了帮助训练收敛,我们将过滤掉长度大于MAX_LENGTH 的句子 (filterPairs)。
"""
MAX_LENGTH=10

def unicodeToAscii(s):
    return "".join(
        c for c in unicodedata.normalize('NFD',s)
        if unicodedata.category(c) != 'Mn'
    )
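
"""
The docstring above mentions normalizeString, but it is not defined in this excerpt.
Below is a minimal sketch in the spirit of the official tutorial (lowercase, pad basic
punctuation with spaces, replace everything else that is not a letter, and collapse
whitespace); the tutorial applies it to each sentence when building the pairs:
"""
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)      # put a space before . ! ?
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)  # replace other non-letter characters with a space
    s = re.sub(r"\s+", r" ", s).strip()    # collapse repeated whitespace
    return s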

# Initialize a Voc object and load the formatted query/response pairs into a list
def readVocs(datafile,corpus_name):
    print("reading lines ...")
    
    # Read the file and split into lines
    lines = open(datafile, encoding='utf-8').read().strip().split('\n')

    # Split every line into a [query, response] pair and convert each sentence to ASCII
    pairs = [[unicodeToAscii(s) for s in l.split('\t')] for l in lines]
    
    voc=Voc(corpus_name)
    
    return voc ,pairs

# Return True only if both sentences in pair 'p' are below the MAX_LENGTH threshold
def filterPair(p):
    
    return len(p[0].split(" "))<MAX_LENGTH and len(p[1].split(" "))<MAX_LENGTH



# Filter the pairs, keeping only those that satisfy the length condition
def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]


# Using the functions defined above, return a populated Voc object and the list of pairs
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    print("Start preparing training data ...")
    voc, pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = filterPairs(pairs)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    print("Counting words...")
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs