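Video source
《PyTorch深度学习实践》 (PyTorch Deep Learning Practice), Lecture 13: Recurrent Neural Networks (Advanced)
Slides download (extraction code: cxe4)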
practice
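Name Classification
Classify which country a name belongs to based on its spelling.
A conventional NLP pipeline looks like this: one-hot encoding of characters/words -> embedding into a lower dimension -> RNN Cell -> Linear (to unify dimensions) -> output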
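For the problem at hand, name classification does not need an output at every time step; only the final hidden state matters, so the model can be simplified (the machine just needs to read the name once from start to finish).
A plain RNN is prone to vanishing/exploding gradients, while an LSTM is relatively expensive to compute, so a GRU is used as a compromise for the model.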
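The simplified structure described above can be sketched as a small PyTorch module. This is only an illustration, not the lecture's final model: the class name `TinyNameClassifier`, the hidden size of 16, and the 18 output classes are placeholder assumptions, and the GRU here is single-layer and unidirectional.

```python
import torch
import torch.nn as nn


class TinyNameClassifier(nn.Module):
    """Hypothetical minimal model: Embedding -> GRU -> Linear on the final hidden state."""

    def __init__(self, n_chars=128, hidden_size=16, n_classes=18):
        super().__init__()
        self.embedding = nn.Embedding(n_chars, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)  # expects (seq_len, batch, features)
        self.fc = nn.Linear(hidden_size, n_classes)

    def forward(self, x):
        # x: (batch, seq_len) tensor of character codes
        emb = self.embedding(x.t())   # (seq_len, batch, hidden_size)
        _, hidden = self.gru(emb)     # per-step outputs are discarded
        return self.fc(hidden[-1])    # (batch, n_classes)


model = TinyNameClassifier()
# ASCII codes of "Mary" and "John" (same length, so no padding is needed here)
batch = torch.tensor([[77, 97, 114, 121], [74, 111, 104, 110]])
logits = model(batch)
print(logits.shape)
```

Only `hidden` is passed to the linear layer, which is what "reading the name once from start to finish" amounts to in code.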
Preparing Data
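Once each character is converted to its ASCII code, it can be turned into a one-hot encoding; the sequences are then padded to a uniform length (so they can be assembled into a tensor).
Countries are converted to class indices.
Model choice: bidirectional RNN/LSTM/GRU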
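The two preparation steps above can be sketched in plain Python. The helper names (`name_to_codes`, `pad_batch`, `country_indices`) and the sample names and countries are made up for illustration; using 0 as the padding value is also an assumption.

```python
def name_to_codes(name):
    """Convert a name to a list of ASCII codes, one per character."""
    return [ord(c) for c in name]


def pad_batch(names, pad_value=0):
    """Right-pad every code sequence to the length of the longest name."""
    seqs = [name_to_codes(n) for n in names]
    max_len = max(len(s) for s in seqs)
    padded = [s + [pad_value] * (max_len - len(s)) for s in seqs]
    lengths = [len(s) for s in seqs]
    return padded, lengths


def country_indices(countries):
    """Map each country string to its index in the sorted list of unique countries."""
    lookup = {c: i for i, c in enumerate(sorted(set(countries)))}
    return [lookup[c] for c in countries], lookup


names = ["Li", "Anna", "Bob"]                    # illustrative data only
countries = ["Chinese", "Russian", "English"]
padded, lengths = pad_batch(names)
labels, lookup = country_indices(countries)
print(padded)   # rectangular list, ready for torch.LongTensor(padded)
print(labels)
```

Keeping the original lengths alongside the padded batch is useful later, since packing utilities need to know where the real characters end.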
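To make the bidirectional option concrete, the sketch below checks the shapes a bidirectional GRU returns. The sizes are arbitrary placeholders and the input is random data; it assumes the batch is already sorted by decreasing length, which `pack_padded_sequence` requires by default.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

hidden_size, n_layers = 8, 2
gru = nn.GRU(input_size=8, hidden_size=hidden_size,
             num_layers=n_layers, bidirectional=True)

# Random stand-in for embedded names: (seq_len=4, batch=3, features=8)
x = torch.randn(4, 3, 8)
lengths = torch.tensor([4, 3, 2])          # true lengths, sorted descending
packed = pack_padded_sequence(x, lengths)  # padding positions are skipped

_, hidden = gru(packed)
# hidden: (n_layers * 2, batch, hidden_size); the last two slices are the
# top layer's forward and backward final states
hidden_cat = torch.cat([hidden[-1], hidden[-2]], dim=1)
print(hidden.shape, hidden_cat.shape)
```

Because the two directions are concatenated, a linear layer placed on top would need `hidden_size * 2` input features.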
code
import csv
import gzip
import math
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
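
# Hyperparameters
HIDDEN_SIZE = 100   # size of the GRU hidden state
BATCH_SIZE = 256
N_LAYER = 2         # number of stacked GRU layers
N_EPOCHS = 100
N_CHARS = 128       # size of the ASCII character set
USE_GPU = False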


class NameDataset(Dataset):
    def __init__(self, is_train_set=True):
        filename = 'data/names_train.csv.gz' if is_train_set else 'data/names_test.csv.gz'
        # Each row of the gzipped CSV is (name, country)
        with gzip.open(filename, 'rt') as f:
            reader = csv.reader(f)
            rows = list(reader)
        self.names = [row[0] for row in rows]
        self.len = len(self.names)
        self.countries = [row[1] for row in rows]
        # Sorted list of unique countries; a country's position is its class index
        self.country_list = sorted(set(self.countries))
        self.country_dict = self.getCountryDict()
        self.country_num = len(self.country_list)

    def getCountryDict(self):
        country_dict = dict()
for idx