from __future__ import unicode_literals, print_function, division

import glob
import os
import string
import unicodedata
from io import open

import torch
# Vocabulary of characters kept by unicodeToAscii: ASCII letters plus a few
# punctuation marks that occur in names.
all_letters = string.ascii_letters + " .,;'-"
n_letters = len(all_letters) + 1  # Plus EOS marker


def findFiles(path):
    """Return the list of filesystem paths matching the glob pattern ``path``."""
    return glob.glob(path)


# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Return ``s`` reduced to the characters listed in ``all_letters``.

    The string is NFD-normalized so that accented characters decompose into
    a base character followed by combining marks; the combining marks
    (Unicode category ``Mn``) are dropped, as is every character that is
    not in ``all_letters``.
    """
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn' and c in all_letters
    )


# Read a file and split into lines
def readLines(filename):
    """Read a UTF-8 text file and return its lines converted to ASCII.

    Leading and trailing whitespace of the whole file content is stripped
    before splitting on ``'\\n'``, and each line is passed through
    ``unicodeToAscii``. An empty file therefore yields ``['']``.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(filename, encoding='utf-8') as handle:
        lines = handle.read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]


# Build the category_lines dictionary, a list of lines per category
# category_lines maps a category name to the ASCII-normalized lines read from
# its file; all_categories lists the category names in the order the files
# were returned by findFiles.
category_lines = {}
all_categories = []
for filename in findFiles('data/names/*.txt'):
    # The category name is the file name without directory or extension.
    category = os.path.splitext(os.path.basename(filename))[0]
    all_categories.append(category)
    lines = readLines(filename)
    category_lines[category] = lines

n_categories = len(all_categories)

# Fail fast with download instructions when the glob matched no data files.
if n_categories == 0:
    raise RuntimeError('Data not found. Make sure that you downloaded data '
                       'from https://download.pytorch.org/tutorial/data.zip and extract it to '
                       'the current directory.')

print('# categories:', n_categories, all_categories)
print(unicodeToAscii("O'Néàl"))
# Data preparation
import random
# Pick a random element (used to draw a random category or a random line)
def randomChoice(l):
    """Return one element of the sequence ``l`` chosen uniformly at random.

    The length of ``l`` is printed on every call (diagnostic output kept
    from the original implementation).

    Raises:
        IndexError: if ``l`` is empty.
    """
    print('len', len(l))
    return random.choice(l)


# Get a random category and random line from that category
def randomTrainingPair():
    """Return a ``(category, line)`` pair.

    The category is drawn at random from ``all_categories`` and the line is
    drawn at random from that category's entry in ``category_lines``.
    """
    category = randomChoice(all_categories)
    line = randomChoice(category_lines[category])
    return category, line
# One-hot vector for category
def categoryTensor(category):
    """Return a ``1 x n_categories`` one-hot tensor for ``category``.

    The position set to 1 is the index of ``category`` in
    ``all_categories``.

    Raises:
        ValueError: if ``category`` is not in ``all_categories``.
    """
    li = all_categories.index(category)
    tensor = torch.zeros(1, n_categories)
    tensor[0][li] = 1
    return tensor
# One-hot matrix of first to last letters (not including EOS) for input
def inputTensor(line):
    """Return a ``len(line) x 1 x n_letters`` one-hot tensor for ``line``.

    Row ``i`` has a 1 at the position of ``line[i]`` in ``all_letters``.
    """
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        # NOTE: str.find returns -1 for a character missing from
        # all_letters, which addresses the last slot (index n_letters - 1).
        # Callers should pass text that only contains characters from
        # all_letters.
        tensor[li][0][all_letters.find(letter)] = 1
    return tensor
# LongTensor of second letter to end (EOS) for target
def targetTensor(line):
    """Return a LongTensor of target indices for ``line``.

    The tensor holds the ``all_letters`` index of every character of
    ``line`` from the second one onward, followed by ``n_letters - 1`` as
    the EOS marker.
    """
    letter_indexes = [all_letters.find(letter) for letter in line[1:]]
    letter_indexes.append(n_letters - 1)  # EOS
    return torch.LongTensor(letter_indexes)


# Make category, input, and target tensors from a random category, line pair
def randomTrainingExample():
    """Return ``(category_tensor, input_line_tensor, target_line_tensor)``.

    A random ``(category, line)`` pair is drawn with ``randomTrainingPair``
    and converted with ``categoryTensor``, ``inputTensor`` and
    ``targetTensor`` respectively.
    """
    category, line = randomTrainingPair()
    category_tensor = categoryTensor(category)
    input_line_tensor = inputTensor(line)
    target_line_tensor = targetTensor(line)
    return category_tensor, input_line_tensor, target_line_tensor
# Build one random training example at module level; the returned tensors are
# not assigned to anything.
randomTrainingExample()