# transformer-xl（片段级递归机制+相对位置编码）（wikitext-103 语言模型）

12 篇文章 0 订阅

wikitext-103

## 三、数据处理(data_utils.py)(vocabulary.py)

data_utils.py

import os, sys
import glob

from collections import Counter, OrderedDict
import numpy as np
import torch

from utils.vocabulary import Vocab
# 生成batch
# 函数调用：data_iter = LMOrderedIterator(self.train, *args, **kwargs)
class LMOrderedIterator(object):
    """Serve one ordered token stream as back-propagation-through-time slices.

    The 1-D ``data`` tensor is trimmed to a multiple of ``bsz`` and reshaped
    to ``[n_step, bsz]`` so that every column is a contiguous sub-stream.
    Batches are ``(data, target, seq_len)`` triples where ``target`` is the
    same stream shifted one position ahead.
    """

    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
        """
        data    -- LongTensor -- the LongTensor is strictly ordered
        bsz     -- number of parallel sub-streams (columns)
        bptt    -- default number of target positions per batch
        device  -- device the reshaped data is moved to
        ext_len -- extra rows of left context prepended to ``data``
                   (``None`` is treated as 0)
        """
        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0
        self.device = device

        # Work out how cleanly we can divide the dataset into bsz parts.
        self.n_step = data.size(0) // bsz

        # Trim off any extra elements that wouldn't cleanly fit (remainders).
        data = data.narrow(0, 0, self.n_step * bsz)

        # Evenly divide the data across the bsz batches -> [n_step, bsz].
        self.data = data.view(bsz, -1).t().contiguous().to(device)

        # Number of mini-batches (ceil(n_step / bptt)).
        self.n_batch = (self.n_step + self.bptt - 1) // self.bptt

    def get_batch(self, i, bptt=None):
        """Return ``(data, target, seq_len)`` for the slice starting at row ``i``.

        ``seq_len`` is clipped so that the one-step-ahead target never runs
        past the last row. ``data`` additionally reaches back up to
        ``ext_len`` rows before ``i``.
        """
        if bptt is None:
            bptt = self.bptt
        seq_len = min(bptt, self.data.size(0) - 1 - i)

        end_idx = i + seq_len
        beg_idx = max(0, i - self.ext_len)

        data = self.data[beg_idx:end_idx]
        target = self.data[i + 1:i + 1 + seq_len]

        return data, target, seq_len

    def get_fixlen_iter(self, start=0):
        """Yield consecutive batches of (at most) ``self.bptt`` rows."""
        for i in range(start, self.data.size(0) - 1, self.bptt):
            yield self.get_batch(i)

    def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
        """Yield batches whose length is randomly jittered around ``bptt``.

        With probability 0.95 the mean length is ``bptt``, otherwise
        ``bptt / 2``; the sampled length is clamped to
        ``[min_len, bptt + max_deviation * std]``.
        """
        max_len = self.bptt + max_deviation * std
        i = start
        while True:
            bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
            bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
            data, target, seq_len = self.get_batch(i, bptt)
            i += seq_len
            yield data, target, seq_len
            # Stop at the end of the data. A leftover debugging cap
            # (``i >= 1000``) used to end every pass after ~1000 rows.
            if i >= self.data.size(0) - 2:
                break

    def __iter__(self):
        return self.get_fixlen_iter()

class LMShuffledIterator(object):
    """Pack independent sentences into fixed-size ``[bptt, bsz]`` batches.

    Each of the ``bsz`` columns consumes sentences from a shared stream; a
    sentence that does not fit is carried over into the next batch of the
    same column. Unfilled cells hold ``-1``.
    """

    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
        """
        data -- list[LongTensor] -- there is no order among the LongTensors
        """
        self.data = data

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = 0 if ext_len is None else ext_len

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self):
        """Yield the sentences once each, in random order when ``shuffle`` is set."""
        n_sents = len(self.data)
        if self.shuffle:
            order = np.random.permutation(n_sents)
        else:
            order = np.arange(n_sents)

        for sent_idx in order:
            yield self.data[sent_idx]

    def stream_iterator(self, sent_stream):
        """Yield ``(data, target, bptt)`` until ``sent_stream`` is exhausted.

        The yielded tensors are reused and overwritten on the next step, so
        a consumer that wants to keep a batch has to copy it.
        """
        # One pending (partially consumed) sentence per batch column.
        pending = [None] * self.bsz

        data = torch.LongTensor(self.bptt, self.bsz)
        target = torch.LongTensor(self.bptt, self.bsz)

        n_retain = 0

        while True:
            # data   : [n_retain+bptt x bsz]
            # target : [bptt x bsz]
            data[n_retain:].fill_(-1)
            target.fill_(-1)

            try:
                for col in range(self.bsz):
                    filled = 0
                    while filled < self.bptt:
                        current = pending[col]
                        if current is None or len(current) <= 1:
                            current = next(sent_stream)
                            pending[col] = current
                        # number of new tokens to fill in
                        take = min(len(current) - 1, self.bptt - filled)
                        # first n_retain tokens are retained from last batch
                        row = n_retain + filled
                        data[row:row + take, col] = current[:take]
                        target[filled:filled + take, col] = current[1:take + 1]
                        pending[col] = current[take:]
                        filled += take
            except StopIteration:
                # Not enough sentences left to complete this batch.
                return

            data = data.to(self.device)
            target = target.to(self.device)

            yield data, target, self.bptt

            # Keep the trailing ext_len rows as left context for the next batch.
            n_retain = min(data.size(0), self.ext_len)
            if n_retain > 0:
                data[:n_retain] = data[-n_retain:]
            data.resize_(n_retain + self.bptt, data.size(1))

    def __iter__(self):
        # sent_stream is an iterator
        yield from self.stream_iterator(self.get_sent_stream())

class LMMultiFileIterator(LMShuffledIterator):
    """Stream batches from several corpus files, one file at a time.

    Each path is encoded through ``vocab`` when it is reached and then packed
    with the inherited ``stream_iterator``.
    """

    def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
                 shuffle=False):
        """
        paths   -- list of file paths to iterate over
        vocab   -- object whose ``encode_file`` turns a path into sentences
        bsz, bptt, device, ext_len, shuffle -- same meaning as in the parent
        """
        self.paths = paths
        self.vocab = vocab

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self, path):
        """Return an iterator over the encoded sentences of ``path``."""
        # The sentences were never loaded here, so ``sents`` was an unbound
        # name. NOTE(review): ``add_double_eos=True`` assumes encode_file
        # accepts that keyword (the '<S>' ... '<S>' wrapping seen in the
        # tokenizer) -- confirm against the full Vocab.encode_file signature.
        sents = self.vocab.encode_file(path, add_double_eos=True)
        if self.shuffle:
            np.random.shuffle(sents)
        sent_stream = iter(sents)

        return sent_stream

    def __iter__(self):
        # NOTE: shuffles the caller-supplied list of paths in place.
        if self.shuffle:
            np.random.shuffle(self.paths)

        for path in self.paths:
            # sent_stream is an iterator
            sent_stream = self.get_sent_stream(path)
            for batch in self.stream_iterator(sent_stream):
                yield batch
# 数据处理：1.读取语料 2.构建词表 3.将token转换成index
# 函数调用： corpus = Corpus(datadir, dataset, **kwargs)
class Corpus(object):
def __init__(self, path, dataset, *args, **kwargs):
self.dataset = dataset
self.vocab = Vocab(*args, **kwargs)# *args表示它是一个可变数组

if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
self.vocab.count_file(os.path.join(path, 'train.txt'))
self.vocab.count_file(os.path.join(path, 'valid.txt'))
self.vocab.count_file(os.path.join(path, 'test.txt'))
elif self.dataset == 'wt103':
# 读取语料，转成token
self.vocab.count_file(os.path.join(path, 'train.txt'))
elif self.dataset == 'lm1b':
train_path_pattern = os.path.join(
path, '1-billion-word-language-modeling-benchmark-r13output',
'training-monolingual.tokenized.shuffled', 'news.en-*')
train_paths = glob.glob(train_path_pattern)
# the vocab will load from file when build_vocab() is called
# 构建词表
self.vocab.build_vocab()

if self.dataset in ['ptb', 'wt2', 'wt103']:
# 将token转换成index
self.train = self.vocab.encode_file(
os.path.join(path, 'train.txt'), ordered=True)
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=True)
self.test  = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=True)
elif self.dataset in ['enwik8', 'text8']:
self.train = self.vocab.encode_file(
self.valid = self.vocab.encode_file(
self.test  = self.vocab.encode_file(
elif self.dataset == 'lm1b':
self.train = train_paths
self.valid = self.vocab.encode_file(
self.test  = self.vocab.encode_file(
# 调用：tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,device=device, ext_len=args.ext_len)
def get_iterator(self, split, *args, **kwargs):
    """Build the batch iterator for one data split.

    Args:
        split: 'train', 'valid' or 'test'.
        *args, **kwargs: forwarded to the iterator class (batch size,
            target length, device, ext_len, ...).

    Returns:
        An LMOrderedIterator for the ordered corpora; for 'lm1b' an
        LMMultiFileIterator (train, always shuffled) or an
        LMShuffledIterator (valid/test).

    Raises:
        ValueError: if ``split`` or ``self.dataset`` is not recognised
            (this used to surface as an UnboundLocalError on return).
    """
    ordered_datasets = ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']

    if split == 'train':
        if self.dataset in ordered_datasets:
            return LMOrderedIterator(self.train, *args, **kwargs)
        if self.dataset == 'lm1b':
            kwargs['shuffle'] = True
            return LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
    elif split in ['valid', 'test']:
        data = self.valid if split == 'valid' else self.test
        if self.dataset in ordered_datasets:
            return LMOrderedIterator(data, *args, **kwargs)
        if self.dataset == 'lm1b':
            return LMShuffledIterator(data, *args, **kwargs)
    else:
        raise ValueError('unknown split: {!r}'.format(split))

    raise ValueError('unsupported dataset: {!r}'.format(self.dataset))
# 数据处理：1.读取语料 2.构建词表 3.将token转换成index
# 如果已存在cache.pt，就直接加载
if os.path.exists(fn):
else:
print('Producing dataset {}...'.format(dataset))
kwargs = {}
if dataset in ['wt103', 'wt2']:
kwargs['special'] = ['<eos>']
kwargs['lower_case'] = False
elif dataset == 'ptb':
kwargs['special'] = ['<eos>']
kwargs['lower_case'] = True
elif dataset == 'lm1b':
kwargs['special'] = []
kwargs['lower_case'] = False
elif dataset in ['enwik8', 'text8']:
pass
# 数据处理：1.读取语料 2.构建词表 3.将token转换成index
corpus = Corpus(datadir, dataset, **kwargs)# **kwargs表示它是一个可变字典
torch.save(corpus, fn)

return corpus

if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='unit test')
help='location of the data corpus')
choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'],
help='dataset name')
args = parser.parse_args()

print('Vocab size : {}'.format(len(corpus.vocab.idx2sym)))



vocabulary.py

import os
from collections import Counter, OrderedDict

import torch
# 函数调用：self.vocab = Vocab(*args, **kwargs)
class Vocab(object):
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
delimiter=None, vocab_file=None):
self.counter = Counter()
self.special = special
self.min_freq = min_freq
self.max_size = max_size
self.lower_case = lower_case
self.delimiter = delimiter #分隔符
self.vocab_file = vocab_file #词表路径

# <---------对句子进行token split并且添加<eos>---->
line = line.strip()
# convert to lower case
if self.lower_case:
line = line.lower()

# empty delimiter '' will evaluate False
if self.delimiter == '':
symbols = line
else:
symbols = line.split(self.delimiter)

return ['<S>'] + symbols + ['<S>']
return symbols + ['<eos>']
else:
return symbols

# <---------读取语料---->
# 函数调用：self.vocab.count_file(os.path.join(path, 'train.txt'))
if verbose: print('counting file {} ...'.format(path))
assert os.path.exists(path)

sents = []
with open(path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0:
print('    line {}'.format(idx))
#对句子进行token split并且添加<eos>
self.counter.update(symbols) # 计算词出现次数
sents.append(symbols)

return sents

def count_sents(self, sents, verbose=False):
    """
    Add the symbols of already-tokenized sentences to ``self.counter``.

    sents : a list of sentences, each a list of tokenized symbols
    """
    if verbose:
        print('counting {} sents ...'.format(len(sents)))
    for line_no, sentence in enumerate(sents):
        # Progress report every 500000 sentences (never for the first one).
        if verbose and line_no > 0 and line_no % 500000 == 0:
            print('    line {}'.format(line_no))
        self.counter.update(sentence)
# 函数调用：self._build_from_file(self.vocab_file)
def _build_from_file(self, vocab_file):
    """Load the vocabulary from ``vocab_file``.

    The first whitespace-separated field of every non-blank line is taken
    as a symbol; symbols get consecutive indices in order of first
    appearance. Sets ``self.idx2sym``, ``self.sym2idx`` and ``self.unk_idx``.

    Raises:
        KeyError: if the file does not contain the '<UNK>' symbol.
    """
    self.idx2sym = []
    self.sym2idx = OrderedDict()

    with open(vocab_file, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank line: nothing to register.
                continue
            symb = fields[0]
            # The parsed symbol was previously discarded, leaving the
            # vocabulary empty and the '<UNK>' lookup below always failing.
            if symb not in self.sym2idx:
                self.idx2sym.append(symb)
                self.sym2idx[symb] = len(self.idx2sym) - 1
    self.unk_idx = self.sym2idx['<UNK>']

# <---------构建vocab---->
# <词典个数267735>
# 函数调用：self.vocab.build_vocab()
def build_vocab(self):
if self.vocab_file:
print('building vocab from {}'.format(self.vocab_file))
self._build_from_file(self.vocab_file)
print('final vocab size {}'.format(len(self)))
else:
print('building vocab with min_freq={}, max_size={}'.format(
self.min_freq, self.max_size))
self.idx2sym = []
self.sym2idx = OrderedDict()

for sym in self.special:

for sym, cnt in self.counter.most_common(self.max_size):
if cnt < self.min_freq: break

print('final vocab size {} from {} unique tokens'.format(
len(self), len(self.counter)))

# <---------将token转换成index---->
# <1801349句>
# 函数调用： self.train = self.vocab.encode_file(os.path.join(path, 'train.txt'), ordered=True)
def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
if verbose: print('encoding file {} ...'.format(path))
assert os.path.exists(path)
encoded = []
with open(path, 'r', encoding='utf-8') as f:
for idx, line in enumerate(f):
if verbose and idx > 0 and idx % 500000 == 0:
print('    line {}'.format(idx))
encoded.append(self.convert_to_tensor(symbols))
# [1,2,3,4]*[2,3,4,5]=[1,2,3,4,2,3,4,5]
if ordered:
encoded = torch.cat(encoded)

return encoded

def encode_sents(self, sents, ordered=False, verbose=False):
    """Convert tokenized sentences to tensors via ``convert_to_tensor``.

    Returns a list with one tensor per sentence, or a single concatenated
    tensor when ``ordered`` is true.
    """
    if verbose:
        print('encoding {} sents ...'.format(len(sents)))

    tensors = []
    for line_no, sentence in enumerate(sents):
        if verbose and line_no > 0 and line_no % 500000 == 0:
            print('    line {}'.format(line_no))
        tensors.append(self.convert_to_tensor(sentence))

    if not ordered:
        return tensors
    return torch.cat(tensors)

if sym not in self.sym2idx:
self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1
setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

if sym not in self.sym2idx:
self.idx2sym.append(sym)
self.sym2idx[sym] = len(self.idx2sym) - 1

def get_sym(self, idx):
    """Return the symbol stored at vocabulary index ``idx``.

    Fails an assertion when ``idx`` is outside ``[0, len(self))``.
    """
    in_range = 0 <= idx < len(self)
    assert in_range, 'Index {} out of range'.format(idx)
    symbol = self.idx2sym[idx]
    return symbol

def get_idx(self, sym):
    """Return the index of ``sym``, or ``self.unk_idx`` for unseen symbols.

    For an unseen symbol two assertions apply: it must not contain
    '<eos>', and the vocabulary must define ``unk_idx``.
    """
    try:
        return self.sym2idx[sym]
    except KeyError:
        pass

    # Unknown symbol: fall back to the unknown-token index.
    assert '<eos>' not in sym
    assert hasattr(self, 'unk_idx')
    return self.unk_idx

def get_symbols(self, indices):
    """Map every index in ``indices`` to its symbol with ``get_sym``."""
    return list(map(self.get_sym, indices))

def get_indices(self, symbols):
    """Map every symbol in ``symbols`` to its index with ``get_idx``."""
    return list(map(self.get_idx, symbols))

def convert_to_tensor(self, symbols):

def convert_to_sent(self, indices, exclude=None):
    """Join the symbols for ``indices`` into one space-separated string.

    Indices contained in ``exclude`` (when given) are left out.
    """
    if exclude is not None:
        indices = (idx for idx in indices if idx not in exclude)
    return ' '.join(self.get_sym(idx) for idx in indices)

def __len__(self):
    """Return the number of symbols currently in ``self.idx2sym``."""
    return len(self.idx2sym)



## 四、模型（mem_transformer.py）

import sys
import math
import functools

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

#sys.path.append('utils')

from utils.log_uniform_sampler import LogUniformSampler,sample_logits
#from log_uniform_sampler import LogUniformSampler, sample_logits

# PositionalEmbedding
# 函数调用：self.pos_emb = PositionalEmbedding(self.d_model)
class PositionalEmbedding(nn.Module):
    """Sinusoidal embedding of a sequence of (real-valued) positions.

    For each position the output is the concatenation of the sines and the
    cosines of ``position * inv_freq``.
    """

    def __init__(self, demb):
        super().__init__()

        self.demb = demb

        # Frequencies 1 / 10000^(2k / demb) for k = 0, 1, ...
        exponents = torch.arange(0.0, demb, 2.0) / demb
        inv_freq = 1 / (10000 ** exponents)
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        """Return embeddings shaped ``[len(pos_seq), 1, ...]``.

        When ``bsz`` is given the singleton middle axis is expanded to
        ``bsz`` (a broadcast view, not a copy).
        """
        angles = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat((angles.sin(), angles.cos()), dim=-1).unsqueeze(1)

        if bsz is None:
            return pos_emb
        return pos_emb.expand(-1, bsz, -1)

# feed-forward network
# 函数调用：self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, pre_lnorm=kwargs.get('pre_lnorm'))
class PositionwiseFF(nn.Module):
    """Two-layer feed-forward sub-layer with residual connection and LayerNorm.

    ``pre_lnorm`` selects where the normalisation sits: on the input of the
    feed-forward net (true) or on the residual sum (false).
    """

    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super().__init__()

        self.d_model, self.d_inner, self.dropout = d_model, d_inner, dropout

        # Linear -> ReLU -> Dropout -> Linear -> Dropout
        stages = [
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        ]
        self.CoreNet = nn.Sequential(*stages)

        self.layer_norm = nn.LayerNorm(d_model)

        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        """Apply the feed-forward sub-layer to ``inp`` and return the result."""
        if not self.pre_lnorm:
            # positionwise feed-forward, then residual + layer normalization
            return self.layer_norm(inp + self.CoreNet(inp))

        # layer normalization + positionwise feed-forward, then residual
        return self.CoreNet(self.layer_norm(inp)) + inp

pre_lnorm=False):

self.d_model = d_model
self.dropout = dropout

self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)

self.layer_norm = nn.LayerNorm(d_model)

self.scale = 1 / (d_head ** 0.5)

self.pre_lnorm = pre_lnorm

if mems is not None:
c = torch.cat([mems, h], 0)
else:
c = h

if self.pre_lnorm:
##### layer normalization
c = self.layer_norm(c)

# [qlen x klen x bsz x n_head]
attn_score.mul_(self.scale)

# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)

attn_vec = attn_vec.contiguous().view(

##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)

if self.pre_lnorm:
##### residual connection
output = h + attn_out
else:
##### residual connection + layer normalization
output = self.layer_norm(h + attn_out)

return output
# self-attention 继承这个类
tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):

self.d_model = d_model
self.dropout = dropout
# 初始化q,k,v=[146,60,41*10]
# embedding * Q ,embedding * k,embedding * V

self.drop = nn.Dropout(dropout)
self.dropatt = nn.Dropout(dropatt)

self.layer_norm = nn.LayerNorm(d_model)

self.scale = 1 / (d_head ** 0.5)

self.pre_lnorm = pre_lnorm

m = min(h, w)

if left:
else:

def _shift(self, x, qlen, klen, mask, left=False):
if qlen > 1:
zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
device=x.device, dtype=x.dtype)
else:

if left:
else:

.view(qlen, klen, x.size(2), x.size(3))

return x
# 函数调用：BD = self._rel_shift(BD)
# 小trick rel-shift
def _rel_shift(self, x, zero_triu=False):
# x:[qlen , rlen , bsz , n_head]=[146,296,60,10]
# 0矩阵:[146,1,60,10]
device=x.device, dtype=x.dtype)
# x:[146, 296, 60, 10] rel_shift

# zero_triu=False
if zero_triu:
ones = torch.ones((x.size(0), x.size(1)))
x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]

return x

def forward(self, w, r, attn_mask=None, mems=None):
    """Placeholder that always raises NotImplementedError."""
    raise NotImplementedError
# self-attention层 相对位置编码
def __init__(self, *args, **kwargs):
# 初始化( pistional encoding) r_head_k = [164,60,41*10]

def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
# w表示embedding,r表示pistional encoding
# w:[146,60,410] r:[146,1,60]
#  qlen=146 rlen=146/296  bsz=60
qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)
# 初始化q,k,v=[146,60,41*10]
# 初始化( pistional encoding) r_head_k = [164,60,41*10]
if mems is not None:
cat = torch.cat([mems, w], 0)
if self.pre_lnorm:
else:
else:
# pre_lnorm=False
if self.pre_lnorm:
else:

#### compute attention score
#AC=[146,146/296,60,10]

#BD=[146,146/296,60,10]
# 小trick real-shift
BD = self._rel_shift(BD)

# [qlen x klen x bsz x n_head]
attn_score = AC + BD
# 除根号dk
attn_score.mul_(self.scale)

#### compute attention probability
# 极小值填充

# softmax
# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)

# softmax * v

attn_vec = attn_vec.contiguous().view(

##### linear projection
# a * Wo
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
# pre_lnorm = False
if self.pre_lnorm:
##### residual connection
output = w + attn_out
else:
##### residual connection + layer normalization
output = self.layer_norm(w + attn_out)

return output

def __init__(self, *args, **kwargs):

def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
# r_bias: [klen, n_head], used for term D

qlen, bsz = w.size(0), w.size(1)

if mems is not None:
cat = torch.cat([mems, w], 0)
if self.pre_lnorm:
else:

else:
if self.pre_lnorm:
else:

if klen > r_emb.size(0):
else:
r_emb = r_emb[-klen:]
r_bias = r_bias[-klen:]

#### compute attention score

B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))                  # qlen x klen x bsz x n_head
D_ = r_bias[None, :, None]                                              # 1    x klen x 1   x n_head
BD = self._rel_shift(B_ + D_)

# [qlen x klen x bsz x n_head]
attn_score = AC + BD
attn_score.mul_(self.scale)

#### compute attention probability

# [qlen x klen x bsz x n_head]
attn_prob = F.softmax(attn_score, dim=1)
attn_prob = self.dropatt(attn_prob)

#### compute attention vector

attn_vec = attn_vec.contiguous().view(

##### linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)

if self.pre_lnorm:
##### residual connection
output = w + attn_out
else:
##### residual connection + layer normalization
output = self.layer_norm(w + attn_out)

return output

class DecoderLayer(nn.Module):
super(DecoderLayer, self).__init__()

self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))

mems=mems)
output = self.pos_ff(output)

return output

class RelLearnableDecoderLayer(nn.Module):
**kwargs):
super(RelLearnableDecoderLayer, self).__init__()

**kwargs)
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))

def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):
    """Run the attention sub-layer followed by the feed-forward sub-layer.

    ``dec_attn_mask`` is forwarded to ``self.dec_attn`` as ``attn_mask``;
    it was previously accepted but silently dropped, so attention ran
    without any mask.
    """
    output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                           attn_mask=dec_attn_mask,
                           mems=mems)
    output = self.pos_ff(output)

    return output
# 对每层定义sublayer: self-attention , feed-forward network
# 函数调用：RelPartialLearnableDecoderLayer(
#                         tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
#                         dropatt=dropatt, pre_lnorm=pre_lnorm)
class RelPartialLearnableDecoderLayer(nn.Module):
**kwargs):
super(RelPartialLearnableDecoderLayer, self).__init__()
# # self-attention层 相对位置编码
self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
pre_lnorm=kwargs.get('pre_lnorm'))

def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):
    """Run the attention sub-layer followed by the feed-forward sub-layer.

    ``dec_attn_mask`` is forwarded to ``self.dec_attn`` as ``attn_mask``;
    it was previously accepted but silently dropped, so attention ran
    without any mask.
    """
    output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
                           attn_mask=dec_attn_mask,
                           mems=mems)
    output = self.pos_ff(output)

    return output

# 函数调用：self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,div_val=div_val)
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
sample_softmax=False):

self.n_token = n_token
self.d_embed = d_embed

self.cutoffs = cutoffs + [n_token]
self.div_val = div_val
self.d_proj = d_proj

self.emb_scale = d_proj ** 0.5

self.cutoff_ends = [0] + self.cutoffs

self.emb_layers = nn.ModuleList()
self.emb_projs = nn.ParameterList()
if div_val == 1:
self.emb_layers.append(
nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
)
if d_proj != d_embed:
self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
else:
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
d_emb_i = d_embed // (div_val ** i)
self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))

def forward(self, inp):
    """Embed token ids and scale the result by ``self.emb_scale``.

    With ``div_val == 1`` a single table is used (plus a projection when
    ``d_proj != d_embed``). Otherwise the ids are split into the clusters
    given by ``self.cutoff_ends``; each cluster has its own table and
    projection, and the projected rows are scattered back into place.
    Returns a tensor of shape ``inp.shape + (d_proj,)`` in the clustered
    path.
    """
    if self.div_val == 1:
        embed = self.emb_layers[0](inp)
        if self.d_proj != self.d_embed:
            embed = F.linear(embed, self.emb_projs[0])
    else:
        param = next(self.parameters())
        inp_flat = inp.view(-1)
        emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                               dtype=param.dtype, device=param.device)
        for i in range(len(self.cutoffs)):
            l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

            mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
            # Positions of the tokens in this cluster. ``indices_i`` was
            # never assigned before; view(-1) keeps it 1-D even when the
            # cluster holds exactly one token.
            indices_i = mask_i.nonzero().view(-1)

            if indices_i.numel() == 0:
                continue

            # Ids are re-based so each cluster table starts at row 0.
            inp_i = inp_flat.index_select(0, indices_i) - l_idx
            emb_i = self.emb_layers[i](inp_i)
            emb_i = F.linear(emb_i, self.emb_projs[i])

            emb_flat.index_copy_(0, indices_i, emb_i)

        embed = emb_flat.view(*inp.size(), self.d_proj)

    embed.mul_(self.emb_scale)

    return embed
# transformer-xl模型： self-attention , feed-forward network，Relative Positional Embedding
# 函数调用： model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
#         tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
#         tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
#         ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
#         same_length=args.same_length, attn_type=args.attn_type,
#         clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
class MemTransformerLM(nn.Module):
dropout, dropatt, tie_weight=True, d_embed=None,
div_val=1, tie_projs=[False], pre_lnorm=False,
tgt_len=None, ext_len=None, mem_len=None,
same_length=False, attn_type=0, clamp_len=-1,
sample_softmax=-1):
super(MemTransformerLM, self).__init__()
self.n_token = n_token
# d_embed=410
d_embed = d_model if d_embed is None else d_embed
self.d_embed = d_embed
# d_model=410
self.d_model = d_model

self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
div_val=div_val)

self.drop = nn.Dropout(dropout)
# n_layer=16
self.n_layer = n_layer
# tgt_len=150 需要predict的token数目
self.tgt_len = tgt_len
# mem_len=150 memory的长度
self.mem_len = mem_len
# ext_len=0 是extended content的长度
self.ext_len = ext_len
self.max_klen = tgt_len + ext_len + mem_len
#定义模型的类型 0代表transfoemer-xl
self.attn_type = attn_type
# 对每层定义sublayer self-attention , feed-forward network
self.layers = nn.ModuleList()
if attn_type == 0: # the default attention
for i in range(n_layer):
self.layers.append(
RelPartialLearnableDecoderLayer(
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm)
)
elif attn_type == 1: # learnable embeddings
for i in range(n_layer):
self.layers.append(
RelLearnableDecoderLayer(
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
dropatt=dropatt, pre_lnorm=pre_lnorm)
)
elif attn_type in [2, 3]: # absolute embeddings
for i in range(n_layer):
self.layers.append(
DecoderLayer(
dropatt=dropatt, pre_lnorm=pre_lnorm)
)
# sample_softmax=-1
self.sample_softmax = sample_softmax
# use sampled softmax
if sample_softmax > 0:
self.out_layer = nn.Linear(d_model, n_token)
if tie_weight:
self.out_layer.weight = self.word_emb.weight
self.tie_weight = tie_weight
self.sampler = LogUniformSampler(n_token, sample_softmax)

# use adaptive softmax (including standard softmax)
else:
cutoffs, div_val=div_val)
# tie_weight=True
if tie_weight:
for i in range(len(self.crit.out_layers)):
self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight
# tie_projs=[False]
if tie_projs:
for i, tie_proj in enumerate(tie_projs):
if tie_proj and div_val == 1 and d_model != d_embed:
self.crit.out_projs[i] = self.word_emb.emb_projs[0]
elif tie_proj and div_val != 1:
self.crit.out_projs[i] = self.word_emb.emb_projs[i]
# same_length=False
self.same_length = same_length
# clamp_len=-1
self.clamp_len = clamp_len
# 得到 Relative Positional Embedding
self._create_params()

def backward_compatible(self):
    """Set ``self.sample_softmax`` to -1 (the value forward() treats as 'not sampled')."""
    self.sample_softmax = -1

# 得到 Relative Positional Embedding
# 函数调用：self._create_params()
def _create_params(self):
# attn_type=0，代表transfoemer-xl
if self.attn_type == 0: # default attention
# PositionalEmbedding
self.pos_emb = PositionalEmbedding(self.d_model)
elif self.attn_type == 1: # learnable
self.r_emb = nn.Parameter(torch.Tensor(
self.r_w_bias = nn.Parameter(torch.Tensor(
self.r_bias = nn.Parameter(torch.Tensor(
elif self.attn_type == 2: # absolute standard
self.pos_emb = PositionalEmbedding(self.d_model)
elif self.attn_type == 3: # absolute deeper SA
self.r_emb = nn.Parameter(torch.Tensor(

def reset_length(self, tgt_len, ext_len, mem_len):
    """Overwrite the stored target, extended-context and memory lengths."""
    self.tgt_len, self.mem_len, self.ext_len = tgt_len, mem_len, ext_len

def init_mems(self):
    """Return the initial memory list, or ``None`` when memory is disabled.

    The list holds ``n_layer + 1`` empty tensors that share dtype and
    device with the module's first parameter.
    """
    if not self.mem_len > 0:
        return None

    ref = next(self.parameters())
    return [
        torch.empty(0, dtype=ref.dtype, device=ref.device)
        for _ in range(self.n_layer + 1)
    ]

# 更新 mems
# 函数调用：new_mems = self._update_mems(hids, mems, mlen, qlen)
def _update_mems(self, hids, mems, qlen, mlen):
    """Build the memory to carry into the next segment.

    For every layer the old memory and the new hidden states are
    concatenated along dim 0 and the window ``[beg_idx, end_idx)`` is kept,
    where ``end_idx = mlen + max(0, qlen - ext_len)`` and the window is at
    most ``mem_len`` long. Returns ``None`` when ``mems`` is ``None``.

    NOTE(review): the visible call site passes ``(hids, mems, mlen, qlen)``
    positionally, i.e. the two lengths arrive in swapped slots. The result
    is identical while ``ext_len == 0``; confirm the intent before using a
    non-zero ``ext_len``.
    """
    # does not deal with None
    if mems is None:
        return None

    # mems is not None
    assert len(hids) == len(mems), 'len(hids) != len(mems)'

    # There are mlen + qlen steps that can be cached into mems
    # For the next step, the last ext_len of the qlen tokens
    # will be used as the extended context. Hence, we only cache
    # the tokens from mlen + qlen - self.ext_len - self.mem_len
    # to mlen + qlen - self.ext_len.
    # The cached states are detached anyway, so do not record the
    # concatenation in the autograd graph.
    with torch.no_grad():
        end_idx = mlen + max(0, qlen - 0 - self.ext_len)
        beg_idx = max(0, end_idx - self.mem_len)
        new_mems = []
        for i in range(len(hids)):
            cat = torch.cat([mems[i], hids[i]], dim=0)
            new_mems.append(cat[beg_idx:end_idx].detach())

    return new_mems

# self-attention ， feed-forward network
# 函数调用： hidden, new_mems = self._forward(data, mems=mems)
def _forward(self, dec_inp, mems=None):
# qlen=146, bsz=60
qlen, bsz = dec_inp.size()#qlen：max_len
# word_emb：[146,60,410]
word_emb = self.word_emb(dec_inp)
# mlen=0
mlen = mems[0].size(0) if mems is not None else 0
# klen=0+146=146
# klen=146+150=296
klen = mlen + qlen#总的长度
# same_length=True
if self.same_length:
# word_emb置1
all_ones = word_emb.new_ones(qlen, klen)
else:
# 形成下三角矩阵
+ torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
else:
word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]

hids = []
# attn_type=0
if self.attn_type == 0: # default
#生成向量，维度等于word embedding
pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
# clamp_len=-1
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
# pos_emb:[146,1,410]
pos_emb = self.pos_emb(pos_seq)
# core_out:[146,60,410]
core_out = self.drop(word_emb)
pos_emb = self.drop(pos_emb)

hids.append(core_out)
# 进入网络层：16层sublayer: self-attention , feed-forward network
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
core_out = layer(core_out, pos_emb, self.r_w_bias,
hids.append(core_out)
elif self.attn_type == 1: # learnable
core_out = self.drop(word_emb)
hids.append(core_out)
for i, layer in enumerate(self.layers):
if self.clamp_len > 0:
r_emb = self.r_emb[i][-self.clamp_len :]
r_bias = self.r_bias[i][-self.clamp_len :]
else:
r_emb, r_bias = self.r_emb[i], self.r_bias[i]

mems_i = None if mems is None else mems[i]
core_out = layer(core_out, r_emb, self.r_w_bias[i],
hids.append(core_out)
elif self.attn_type == 2: # absolute
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
pos_emb = self.pos_emb(pos_seq)

core_out = self.drop(word_emb + pos_emb[-qlen:])

hids.append(core_out)
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
if mems_i is not None and i == 0:
mems_i += pos_emb[:mlen]
mems=mems_i)
hids.append(core_out)
elif self.attn_type == 3:
core_out = self.drop(word_emb)

hids.append(core_out)
for i, layer in enumerate(self.layers):
mems_i = None if mems is None else mems[i]
if mems_i is not None and mlen > 0:
cur_emb = self.r_emb[i][:-qlen]
cur_size = cur_emb.size(0)
if cur_size < mlen:
else:
cur_emb = cur_emb[-mlen:]
mems_i += cur_emb.view(mlen, 1, -1)
core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

mems=mems_i)
hids.append(core_out)

core_out = self.drop(core_out)
# 更新 mens
new_mems = self._update_mems(hids, mems, mlen, qlen)

return core_out, new_mems

def forward(self, data, target, *mems):
    """Compute the per-position loss for one segment.

    Returns ``[loss]`` followed by the updated memory tensors (when the
    inner ``_forward`` produced any). ``loss`` is reshaped to
    ``[target.size(0), -1]`` in the non-sampled path.
    """
    # No memory passed in: start from the initial memory.
    if not mems:
        mems = self.init_mems()

    tgt_len = target.size(0)
    hidden, new_mems = self._forward(data, mems=mems)

    # Only the last tgt_len positions are predicted.
    pred_hid = hidden[-tgt_len:]

    use_sampled_softmax = self.sample_softmax > 0 and self.training
    if use_sampled_softmax:
        assert self.tie_weight
        logit = sample_logits(self.word_emb,
                              self.out_layer.bias, target, pred_hid, self.sampler)
        loss = -F.log_softmax(logit, -1)[:, :, 0]
    else:
        flat_hid = pred_hid.view(-1, pred_hid.size(-1))
        flat_tgt = target.view(-1)
        loss = self.crit(flat_hid, flat_tgt).view(tgt_len, -1)

    if new_mems is None:
        return [loss]
    return [loss] + new_mems

if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser(description='unit test')

args = parser.parse_args()

device = torch.device("cuda" if args.cuda else "cpu")

B = 4
tgt_len, mem_len, ext_len = 36, 36, 0
data_len = tgt_len * 20
args.n_token = 10000

import data_utils

data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device)
diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len)

cutoffs = [args.n_token // 2]
tie_projs = [False] + [True] * len(cutoffs)

for div_val in [1, 2]:
for d_embed in [200, 100]:
dropatt=args.dropout, tie_weight=True,
d_embed=d_embed, div_val=div_val,
tie_projs=tie_projs, pre_lnorm=True,
tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
cutoffs=cutoffs, attn_type=0).to(device)

print(sum(p.numel() for p in model.parameters()))

mems = tuple()
for idx, (inp, tgt, seqlen) in enumerate(diter):
print('batch {}'.format(idx))
out = model(inp, tgt, *mems)
mems = out[1:]



## 五、训练(train.py)

# coding: utf-8
import argparse
import time
import math
import os, sys
import itertools
import args as args
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.exp_utils import create_exp_dir
from utils.data_parallel import BalancedDataParallel

# Command-line interface of the training script.
#
# NOTE(review): in this listing every `parser.add_argument(` head had been
# lost, leaving only the trailing `help=...)` fragments. The heads (option
# names, types, defaults) are restored below to match the surviving help
# strings; confirm the defaults against the reference implementation. Four
# options whose help lines were also missing (--n_head, --d_head, --clip,
# --adaptive) are restored as well.
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--data', type=str, default='../data/wikitext-103',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='wt103',
                    choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                    help='dataset name')
parser.add_argument('--n_layer', type=int, default=12,
                    help='number of total layers')
parser.add_argument('--n_head', type=int, default=10,
                    help='number of heads')
parser.add_argument('--d_head', type=int, default=50,
                    help='head dimension')
parser.add_argument('--d_embed', type=int, default=-1,
                    help='embedding dimension')
parser.add_argument('--d_model', type=int, default=500,
                    help='model dimension')
parser.add_argument('--d_inner', type=int, default=1000,
                    help='inner dimension in FF')
parser.add_argument('--dropout', type=float, default=0.0,
                    help='global dropout rate')
parser.add_argument('--dropatt', type=float, default=0.0,
                    help='attention probability dropout rate')
parser.add_argument('--init', default='normal', type=str,
                    help='parameter initializer to use.')
parser.add_argument('--emb_init', default='normal', type=str,
                    help='parameter initializer to use.')
parser.add_argument('--init_range', type=float, default=0.1,
                    help='parameters initialized by U(-init_range, init_range)')
parser.add_argument('--emb_init_range', type=float, default=0.01,
                    help='parameters initialized by U(-init_range, init_range)')
parser.add_argument('--init_std', type=float, default=0.02,
                    help='parameters initialized by N(0, init_std)')
parser.add_argument('--proj_init_std', type=float, default=0.01,
                    help='parameters initialized by N(0, init_std)')
parser.add_argument('--optim', default='adam', type=str,
                    choices=['adam', 'sgd', 'adagrad'],
                    help='optimizer to use.')
parser.add_argument('--lr', type=float, default=0.00025,
                    help='initial learning rate (0.00025|5 for adam|sgd)')
parser.add_argument('--mom', type=float, default=0.0,
                    help='momentum for sgd')
parser.add_argument('--scheduler', default='cosine', type=str,
                    choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
                    help='lr scheduler to use.')
parser.add_argument('--warmup_step', type=int, default=0,
                    help='upper epoch limit')
parser.add_argument('--decay_rate', type=float, default=0.5,
                    help='decay factor when ReduceLROnPlateau is used')
parser.add_argument('--lr_min', type=float, default=0.0,
                    help='minimum learning rate during annealing')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--clip_nonemb', action='store_true',
                    help='only clip the gradient of non-embedding params')
parser.add_argument('--max_step', type=int, default=100000,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=60,
                    help='batch size')
parser.add_argument('--batch_chunk', type=int, default=1,
                    help='split batch into chunks to save memory')
parser.add_argument('--tgt_len', type=int, default=70,
                    help='number of tokens to predict')
parser.add_argument('--eval_tgt_len', type=int, default=50,
                    help='number of tokens to predict for evaluation')
parser.add_argument('--ext_len', type=int, default=0,
                    help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
                    help='length of the retained previous heads')
parser.add_argument('--not_tied', action='store_true',
                    help='do not tie the word embedding and softmax weights')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--adaptive', action='store_true',
                    help='use adaptive softmax')
parser.add_argument('--div_val', type=int, default=1,
                    help='divident value for adapative input and softmax')
parser.add_argument('--pre_lnorm', action='store_true',
                    help='apply LayerNorm to the input instead of the output')
parser.add_argument('--varlen', action='store_true',
                    help='use variable length')
parser.add_argument('--multi_gpu', action='store_true',
                    help='use multiple GPU')
parser.add_argument('--log-interval', type=int, default=200,
                    help='report interval')
parser.add_argument('--eval-interval', type=int, default=4000,
                    help='evaluation interval')
parser.add_argument('--work_dir', default='LM-TFM', type=str,
                    help='experiment directory.')
parser.add_argument('--restart', action='store_true',
                    help='restart training from the saved checkpoint')
parser.add_argument('--restart_dir', type=str, default='',
                    help='restart dir')
parser.add_argument('--debug', action='store_true',
                    help='run in debug mode (do not create exp dir)')
parser.add_argument('--same_length', action='store_true',
                    help='use the same attn length for all tokens')
parser.add_argument('--attn_type', type=int, default=0,
                    help='attention type. 0 for ours, 1 for Shaw et al,'
                    '2 for Vaswani et al, 3 for Al Rfou et al.')
parser.add_argument('--clamp_len', type=int, default=-1,
                    help='use the same pos embeddings after clamp_len')
parser.add_argument('--eta_min', type=float, default=0.0,
                    help='min learning rate for cosine scheduler')
parser.add_argument('--gpu0_bsz', type=int, default=-1,
                    help='batch size on gpu 0')
parser.add_argument('--max_eval_steps', type=int, default=-1,
                    help='max eval steps')
parser.add_argument('--sample_softmax', type=int, default=-1,
                    help='number of samples in sampled softmax')
parser.add_argument('--patience', type=int, default=0,
                    help='patience')
parser.add_argument('--finetune_v2', action='store_true',
                    help='finetune v2')
parser.add_argument('--finetune_v3', action='store_true',
                    help='finetune v3')
parser.add_argument('--fp16', action='store_true',
                    help='Run in pseudo-fp16 mode (fp16 storage fp32 math).')
parser.add_argument('--static-loss-scale', type=float, default=1,
                    help='Static loss scale, positive power of 2 values can '
                    'improve fp16 convergence.')
parser.add_argument('--dynamic-loss-scale', action='store_true',
                    help='Use dynamic loss scaling.  If supplied, this argument'
                    ' supersedes --static-loss-scale.')
# parse_args() is deliberately not called in this version of the script:
# `args` is the module imported at the top of the file (`import args as args`)
# and its attributes are overridden by the assignments that follow.
# Weight tying between the word embedding and the softmax is switched off.
args.tied = False

# The `d_embed < 0` guard is disabled in this version, so the embedding width
# always follows the model width.
args.d_embed = args.d_model

# NOTE(review): the sanity checks on ext_len (must be >= 0) and on
# batch_size % batch_chunk == 0 are disabled in this version.
args.work_dir = "LM-TFM"
args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)
args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
# Logging helper for the experiment.
# NOTE(review): debug=True is hard-coded; presumably this makes
# create_exp_dir skip creating work_dir -- verify, because checkpoints are
# later written to and read from that directory.
logging = create_exp_dir(args.work_dir,
    scripts_to_save=['train.py', 'mem_transformer.py'], debug=True)

# Set the random seed manually for reproducibility.
args.seed = 1111
args.cuda = True
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print('WARNING: You have a CUDA device, so you should probably run with --cuda')
    else:
        torch.cuda.manual_seed_all(args.seed)

# Validate --fp16 option
args.fp16 = True
if args.fp16:
    if not args.cuda:
        print('WARNING: --fp16 requires --cuda, ignoring --fp16 option')
        args.fp16 = False
    else:
        try:
            from apex.fp16_utils import FP16_Optimizer
        except Exception:
            # Best effort: any failure while importing apex disables fp16.
            # `Exception` (instead of a bare except) keeps Ctrl+C and
            # SystemExit working.
            print('WARNING: apex not installed, ignoring --fp16 option')
            args.fp16 = False

###############################################################################
###############################################################################
# Data processing: 1. read the corpus  2. build the vocabulary
# 3. convert tokens to indices
corpus = get_lm_corpus(args.data, args.dataset)
ntokens = len(corpus.vocab)
args.n_token = ntokens

eval_batch_size = 10
args.ext_len = 0
# The device follows actual CUDA availability. An earlier assignment based on
# args.cuda was overwritten by this one before any use and has been dropped.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Batch iterators for the train / valid / test splits.
tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
    device=device, ext_len=args.ext_len)
va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len,
    device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len,
    device=device, ext_len=args.ext_len)

# Adaptive softmax / embedding: vocabulary cutoffs per dataset.
# NOTE(review): this is applied unconditionally, so only 'wt103' and 'lm1b'
# are accepted as datasets here.
cutoffs, tie_projs = [], [False]
assert args.dataset in ['wt103', 'lm1b']
if args.dataset == 'wt103':
    cutoffs = [20000, 40000, 200000]
    tie_projs += [True] * len(cutoffs)
elif args.dataset == 'lm1b':
    cutoffs = [60000, 100000, 640000]
    tie_projs += [False] * len(cutoffs)

###############################################################################
# Build the model
###############################################################################
# Initialisation scheme used by init_weight below.
args.init = 'normal'
args.init_std = 0.02


def init_weight(weight):
    """Initialise ``weight`` in place according to ``args.init``.

    'uniform' draws from U(-args.init_range, args.init_range) and 'normal'
    from N(0, args.init_std). Any other value leaves ``weight`` untouched.
    """
    if args.init == 'uniform':
        nn.init.uniform_(weight, -args.init_range, args.init_range)
    elif args.init == 'normal':
        nn.init.normal_(weight, 0.0, args.init_std)

def init_bias(bias):
    """Fill ``bias`` with zeros in place."""
    nn.init.constant_(bias, 0.0)

def weights_init(m):
    """Initialise the parameters of module ``m``; meant for ``model.apply``.

    The branch is chosen by substring match on the module's class name and
    each branch initialises only the attributes the module actually has.

    NOTE(review): the headers of the 'AdaptiveEmbedding' and
    'ProjectedAdaptiveLogSoftmax' branches were missing from this listing,
    which left their bodies attached to the 'Linear' and 'Embedding' branches.
    They are restored here; confirm the class names against the modules
    defined in mem_transformer.py.
    """
    classname = m.__class__.__name__
    if 'Linear' in classname:
        if hasattr(m, 'weight') and m.weight is not None:
            init_weight(m.weight)
        if hasattr(m, 'bias') and m.bias is not None:
            init_bias(m.bias)
    # Must be tested before the generic 'Embedding' branch, which would
    # otherwise match the same class name.
    elif 'AdaptiveEmbedding' in classname:
        if hasattr(m, 'emb_projs'):
            for proj in m.emb_projs:
                if proj is not None:
                    nn.init.normal_(proj, 0.0, args.proj_init_std)
    elif 'Embedding' in classname:
        if hasattr(m, 'weight'):
            init_weight(m.weight)
    elif 'ProjectedAdaptiveLogSoftmax' in classname:
        if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
            init_weight(m.cluster_weight)
        if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
            init_bias(m.cluster_bias)
        if hasattr(m, 'out_projs'):
            for proj in m.out_projs:
                if proj is not None:
                    nn.init.normal_(proj, 0.0, args.proj_init_std)
    elif 'LayerNorm' in classname:
        if hasattr(m, 'weight'):
            nn.init.normal_(m.weight, 1.0, args.init_std)
        if hasattr(m, 'bias') and m.bias is not None:
            init_bias(m.bias)
    elif 'TransformerLM' in classname:
        if hasattr(m, 'r_emb'):
            init_weight(m.r_emb)
        if hasattr(m, 'r_w_bias'):
            init_weight(m.r_w_bias)
        if hasattr(m, 'r_r_bias'):
            init_weight(m.r_r_bias)
        if hasattr(m, 'r_bias'):
            init_bias(m.r_bias)

def update_dropout(m):
    """Set the rate of every Dropout-like module to ``args.dropout``.

    Meant for ``model.apply``: a module is touched when its class name
    contains 'Dropout' and it exposes a ``p`` attribute.
    """
    classname = m.__class__.__name__
    if 'Dropout' in classname and hasattr(m, 'p'):
        m.p = args.dropout

def update_dropatt(m):
    """Set ``m.dropatt.p`` to ``args.dropatt`` on modules that have ``dropatt``.

    Meant for ``model.apply``; modules without a ``dropatt`` attribute are
    left unchanged.
    """
    if hasattr(m, 'dropatt'):
        m.dropatt.p = args.dropatt

# Hard-coded model options for this run.
args.restart = False
args.dropatt = 0.0
args.div_val = 1
args.pre_lnorm = True
args.same_length = True
args.attn_type = 0
args.clamp_len = -1
args.sample_softmax = -1
if args.restart:
    with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
        # The statement that loads the checkpoint was missing from this
        # listing, leaving `model` undefined on this path.
        # SECURITY: torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # NOTE(review): recent PyTorch releases may require an explicit
        # weights_only=False to load a fully pickled model -- confirm.
        model = torch.load(f)
    if not args.fp16:
        model = model.float()
    model.apply(update_dropout)
    model.apply(update_dropatt)
else:
    # Build the Transformer-XL model (self-attention, feed-forward network,
    # relative positional embedding).
    # NOTE(review): the second line of positional arguments was missing from
    # this listing; confirm the order against MemTransformerLM.__init__.
    model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
        args.d_head, args.d_inner, args.dropout, args.dropatt,
        tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
        tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
        ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
        same_length=args.same_length, attn_type=args.attn_type,
        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
    model.apply(weights_init)
    # With weight sharing, make sure the embedding init is not overwritten by
    # the init of the output layer.
    model.word_emb.apply(weights_init)
# Parameter counts, reported in the log.
args.n_all_param = sum(p.nelement() for p in model.parameters())
args.n_nonemb_param = sum(p.nelement() for p in model.layers.parameters())

if args.fp16:
    model = model.half()
args.multi_gpu = False
if args.multi_gpu:
    model = model.to(device)
    if args.gpu0_bsz >= 0:
        para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                          model, dim=1).to(device)
    else:
        para_model = nn.DataParallel(model, dim=1).to(device)
else:
    # Single device: para_model is the plain model.
    para_model = model.to(device)

#### optimizer
# With sampled softmax, parameters whose shape equals the word-embedding
# matrix go to a separate "sparse" optimizer; everything else goes to the
# dense one.
# NOTE(review): the 'adam' and 'adagrad' branch headers and their optimizer
# constructions were missing from this listing (only a dangling `else:` was
# left). They are restored here; confirm against the reference
# implementation.
if args.optim.lower() == 'sgd':
    if args.sample_softmax > 0:
        dense_params, sparse_params = [], []
        for param in model.parameters():
            if param.size() == model.word_emb.weight.size():
                sparse_params.append(param)
            else:
                dense_params.append(param)
        optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
        optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
            momentum=args.mom)
elif args.optim.lower() == 'adam':
    if args.sample_softmax > 0:
        dense_params, sparse_params = [], []
        for param in model.parameters():
            if param.size() == model.word_emb.weight.size():
                sparse_params.append(param)
            else:
                dense_params.append(param)
        optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
        optimizer = optim.Adam(dense_params, lr=args.lr)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
elif args.optim.lower() == 'adagrad':
    optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
else:
    # Fail here rather than with a NameError on `optimizer` further down.
    raise ValueError('unsupported optimizer: {}'.format(args.optim))
# Hard-coded scheduler choice for this run.
args.scheduler = 'cosine'
args.eta_min = 0.0
#### scheduler
if args.scheduler == 'cosine':
    # here we do not set eta_min to lr_min to be backward compatible
    # because in previous versions eta_min is default to 0
    # rather than the default value of lr_min 1e-6
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
        args.max_step, eta_min=args.eta_min) # should use eta_min arg
    if args.sample_softmax > 0:
        scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse,
            args.max_step, eta_min=args.eta_min) # should use eta_min arg
elif args.scheduler == 'inv_sqrt':
    # originally used for Transformer (in Attention is all you need)
    def lr_lambda(step):
        """Return the learning-rate multiplier for ``step``."""
        # return a multiplier instead of a learning rate
        if step == 0 and args.warmup_step == 0:
            return 1.
        else:
            return 1. / (step ** 0.5) if step > args.warmup_step \
                   else step / (args.warmup_step ** 1.5)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
elif args.scheduler == 'dev_perf':
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
        factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
    if args.sample_softmax > 0:
        scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse,
            factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
elif args.scheduler == 'constant':
    # No scheduler object is created for a constant learning rate.
    pass

if args.cuda and args.fp16:
    # If args.dynamic_loss_scale is False, static_loss_scale will be used.
    # If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale.
    optimizer = FP16_Optimizer(optimizer,
                               static_loss_scale = args.static_loss_scale,
                               dynamic_loss_scale = args.dynamic_loss_scale,
                               dynamic_loss_args = {'init_scale': 2 ** 16})

if args.restart:
    if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')):
        with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f:
            # The body of this `with` was missing from the listing; restore
            # the saved optimizer state.
            # SECURITY: torch.load unpickles the file; only load checkpoints
            # from a trusted source.
            opt_state_dict = torch.load(f)
            optimizer.load_state_dict(opt_state_dict)
    else:
        print('Optimizer was not saved. Start from scratch.')

# Dump the effective configuration and the parameter counts to the log.
logging('=' * 100)
for k, v in args.__dict__.items():
    logging('    - {} : {}'.format(k, v))
logging('=' * 100)
logging('#params = {}'.format(args.n_all_param))
logging('#non emb params = {}'.format(args.n_nonemb_param))

###############################################################################
# Training code
###############################################################################

def evaluate(eval_iter):
    """Return the average per-token loss of ``model`` over ``eval_iter``.

    The model is put into eval mode with the evaluation segment lengths,
    every batch is scored while the memory is carried from batch to batch,
    and the model is switched back to its training lengths and training mode
    before returning.

    Args:
        eval_iter: iterable yielding ``(data, target, seq_len)`` batches.

    Returns:
        float: sum of ``seq_len * mean batch loss`` divided by the summed
        ``seq_len``.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()

    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
    if args.mem_len == 0:
        model.reset_length(args.eval_tgt_len,
            args.ext_len+args.tgt_len-args.eval_tgt_len, args.mem_len)
    else:
        model.reset_length(args.eval_tgt_len,
            args.ext_len, args.mem_len+args.tgt_len-args.eval_tgt_len)

    # Evaluation
    total_len, total_loss = 0, 0.
    # No gradients are needed here; skipping them avoids building an
    # autograd graph across the carried memory.
    with torch.no_grad():
        mems = tuple()
        for i, (data, target, seq_len) in enumerate(eval_iter):
            if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                break
            ret = model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.mean()
            total_loss += seq_len * loss.float().item()
            total_len += seq_len

    # Switch back to the training mode
    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    model.train()

    # Callers use the result as a loss value (math.exp(val_loss) etc.), so it
    # has to be returned; the return statement was missing from the listing.
    return total_loss / total_len

# Hard-coded training options for this run.
args.batch_chunk = 1
args.varlen = True
args.clip = 0.25
args.log_interval = 200
args.eval_interval = 4000


def train():
    """Train for one pass over the training iterator.

    Works on the module-level ``model`` / ``para_model``, ``optimizer`` and
    schedulers, and updates the global counters ``train_step``,
    ``train_loss``, ``best_val_loss`` and the two timing variables. Every
    ``args.log_interval`` steps a progress line is logged; every
    ``args.eval_interval`` steps the model is evaluated on the validation
    iterator and, when the validation loss improved and ``args.debug`` is
    off, model and optimizer state are saved to ``args.work_dir``. The loop
    stops early once ``train_step`` reaches ``args.max_step``.
    """
    # Turn on training mode which enables dropout.
    global train_step, train_loss, best_val_loss, eval_start_time, log_start_time
    model.train()
    # One memory tuple per batch chunk, or a single one when not chunking.
    if args.batch_chunk > 1:
        mems = [tuple() for _ in range(args.batch_chunk)]
    else:
        mems = tuple()
    train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter

    for batch, (data, target, seq_len) in enumerate(train_iter):
        # Clear the gradients of the previous step; without this they would
        # accumulate across batches. (The call was missing from the listing.)
        model.zero_grad()
        if args.batch_chunk > 1:
            # Split along the batch dimension and accumulate gradients
            # chunk by chunk to save memory.
            data_chunks = torch.chunk(data, args.batch_chunk, 1)
            target_chunks = torch.chunk(target, args.batch_chunk, 1)
            for i in range(args.batch_chunk):
                data_i = data_chunks[i].contiguous()
                target_i = target_chunks[i].contiguous()
                ret = para_model(data_i, target_i, *mems[i])
                loss, mems[i] = ret[0], ret[1:]
                loss = loss.float().mean().type_as(loss) / args.batch_chunk
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                train_loss += loss.float().item()
        else:
            ret = para_model(data, target, *mems)
            # ret[0] is the loss; the rest is the memory for the next batch.
            loss, mems = ret[0], ret[1:]
            loss = loss.float().mean().type_as(loss)
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
            train_loss += loss.float().item()

        # Gradient clipping. Both branch bodies were missing from the
        # listing, which left an empty if/else.
        # NOTE(review): clip_master_grads is the clipping method expected on
        # apex's FP16_Optimizer -- confirm against the installed apex.
        if args.fp16:
            optimizer.clip_master_grads(args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()
        if args.sample_softmax > 0:
            optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if args.sample_softmax > 0:
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            else:
                if args.scheduler == 'cosine':
                    scheduler.step(train_step)
                    if args.sample_softmax > 0:
                        scheduler_sparse.step(train_step)
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)

        if train_step % args.log_interval == 0:
            cur_loss = train_loss / args.log_interval
            elapsed = time.time() - log_start_time
            log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \
                      '| ms/batch {:5.2f} | loss {:5.2f}'.format(
                epoch, train_step, batch+1, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, cur_loss)
            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
            else:
                log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
            logging(log_str)
            train_loss = 0
            log_start_time = time.time()

        if train_step % args.eval_interval == 0:
            val_loss = evaluate(va_iter)
            logging('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f}'.format(
                train_step // args.eval_interval, train_step,
                (time.time() - eval_start_time), val_loss)
            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
            logging(log_str)
            logging('-' * 100)
            # Save the model if the validation loss is the best we've seen so far.
            # `is None` (not truthiness) so a best loss of 0.0 is not
            # mistaken for "no value yet".
            if best_val_loss is None or val_loss < best_val_loss:
                if not args.debug:
                    with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f:
                        torch.save(model, f)
                    with open(os.path.join(args.work_dir, 'optimizer.pt'), 'wb') as f:
                        torch.save(optimizer.state_dict(), f)
                best_val_loss = val_loss

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if args.sample_softmax > 0:
                    scheduler_sparse.step(val_loss)

            eval_start_time = time.time()

        if train_step == args.max_step:
            break

# Loop over epochs.
train_step = 0
train_loss = 0
best_val_loss = None

log_start_time = time.time()
eval_start_time = time.time()

# At any point you can hit Ctrl + C to break out of training early.
try:
    # `epoch` is read as a global inside train() for its log line.
    for epoch in itertools.count(start=1):
        train()
        if train_step == args.max_step:
            logging('-' * 100)
            logging('End of training')
            break
except KeyboardInterrupt:
    logging('-' * 100)
    logging('Exiting from training early')

# Load the best saved model.
with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
    # The load statement was missing from the listing, so the test run below
    # would have used the last model rather than the best checkpoint.
    # SECURITY: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    # NOTE(review): recent PyTorch releases may require an explicit
    # weights_only=False to load a fully pickled model -- confirm.
    model = torch.load(f)
para_model = model.to(device)

# Run on test data.
test_loss = evaluate(te_iter)
logging('=' * 100)
if args.dataset in ['enwik8', 'text8']:
    logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
        test_loss, test_loss / math.log(2)))
else:
    logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
        test_loss, math.exp(test_loss)))
logging('=' * 100)



from collections import defaultdict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
# CUDA_MINOR = int(torch.version.cuda.split('.')[1])

# 函数调用：self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model, cutoffs, div_val=div_val)
class ProjectedAdaptiveLogSoftmax(nn.Module):
    """Adaptive log-softmax loss with optional per-cluster projections.

    The vocabulary is split at ``cutoffs`` into a shortlist (ids below the
    first cutoff) and one tail cluster per remaining interval. The head
    distribution covers the shortlist plus one entry per tail cluster; a tail
    token's log-probability is the log-probability of its cluster in the head
    plus its log-probability inside the cluster.
    """

    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 keep_order=False):
        """Create the output layers, projections and cluster parameters.

        Args:
            n_token: vocabulary size.
            d_embed: width of the (first) output embedding.
            d_proj: width of the hidden states passed to ``forward``.
            cutoffs: increasing vocabulary ids at which the clusters start,
                e.g. ``[20000, 40000, 200000]``; ``n_token`` is appended.
            div_val: with 1, a single ``Linear(d_embed, n_token)`` is sliced
                per cluster; otherwise cluster ``i`` gets its own layer of
                width ``d_embed // div_val ** i``.
            keep_order: default for the ``keep_order`` behaviour of
                ``forward``.
        """
        super(ProjectedAdaptiveLogSoftmax, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        # Right boundaries of all clusters, and the same list with a leading 0.
        self.cutoffs = list(cutoffs) + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        # Head = shortlist tokens + one logit per tail cluster.
        self.head_size = self.shortlist_size + self.n_clusters

        if self.n_clusters > 0:
            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
        self.out_projs = nn.ParameterList()

        # The projection tensors are allocated uninitialised here and must be
        # initialised by the caller before use.
        if div_val == 1:
            # One (optional) projection per cluster, one shared output layer.
            for _ in self.cutoffs:
                if d_proj != d_embed:
                    self.out_projs.append(
                        nn.Parameter(torch.Tensor(d_proj, d_embed))
                    )
                else:
                    self.out_projs.append(None)
            self.out_layers.append(nn.Linear(d_embed, n_token))
        else:
            # One projection and one output layer per cluster, with the
            # width divided by div_val for each further cluster.
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)

                self.out_projs.append(
                    nn.Parameter(torch.Tensor(d_proj, d_emb_i))
                )

                self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))

        self.keep_order = keep_order

    def _compute_logit(self, hidden, weight, bias, proj):
        """Return ``hidden @ proj @ weight.T + bias`` (``proj`` is optional)."""
        if proj is None:
            return F.linear(hidden, weight, bias=bias)
        # Project the hidden states to the embedding width first.
        proj_hid = F.linear(hidden, proj.t().contiguous())
        return F.linear(proj_hid, weight, bias=bias)

    def forward(self, hidden, target, keep_order=False):
        '''
        Compute the negative log-likelihood of every target token.

        hidden :: [len*bsz x d_proj]
        target :: [len*bsz]

        Returns a tensor of shape [len*bsz] with the dtype of ``hidden``.
        With ``keep_order`` (argument or attribute) the values are aligned
        with ``target``; otherwise they are grouped cluster by cluster, which
        is enough when only the mean or sum is needed.

        Raises RuntimeError when ``hidden`` and ``target`` differ in size
        along dimension 0.
        '''

        if hidden.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        if self.n_clusters == 0:
            # No tail clusters: plain log-softmax over the whole vocabulary.
            logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                        self.out_layers[0].bias, self.out_projs[0])
            nll = -F.log_softmax(logit, dim=-1) \
                    .gather(1, target.unsqueeze(1)).squeeze(1)
            return nll

        # construct weights and biases
        weights, biases = [], []
        for i in range(len(self.cutoffs)):
            if self.div_val == 1:
                # Slice the rows of the shared layer that belong to cluster i.
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                weight_i = self.out_layers[0].weight[l_idx:r_idx]
                bias_i = self.out_layers[0].bias[l_idx:r_idx]
            else:
                weight_i = self.out_layers[i].weight
                bias_i = self.out_layers[i].bias

            if i == 0:
                # The head also scores the tail clusters themselves.
                weight_i = torch.cat(
                    [weight_i, self.cluster_weight], dim=0)
                bias_i = torch.cat(
                    [bias_i, self.cluster_bias], dim=0)

            weights.append(weight_i)
            biases.append(bias_i)

        # Head distribution over shortlist tokens and cluster entries.
        head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]
        head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
        head_logprob = F.log_softmax(head_logit, dim=1)

        nll = torch.zeros_like(target,
                               dtype=hidden.dtype, device=hidden.device)

        keep = getattr(self, 'keep_order', False) or keep_order
        offset = 0
        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):
            l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]

            # Positions whose target falls into the current interval.
            mask_i = (target >= l_idx) & (target < r_idx)
            # squeeze(1) keeps the index 1-D even when exactly one position
            # matches (a bare squeeze() would give a 0-d tensor).
            indices_i = mask_i.nonzero().squeeze(1)

            if indices_i.numel() == 0:
                continue

            # Target ids relative to the start of the interval.
            target_i = target.index_select(0, indices_i) - l_idx
            head_logprob_i = head_logprob.index_select(0, indices_i)

            if i == 0:
                # Shortlist tokens are read straight from the head.
                logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1)
            else:
                weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]

                # Tail tokens: score only the selected hidden states inside
                # their cluster.
                hidden_i = hidden.index_select(0, indices_i)

                tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

                # Cluster i is read from head column -i, i.e. counted from
                # the end of the head.
                logprob_i = head_logprob_i[:, -i] \
                          + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1)

            if keep:
                nll.index_copy_(0, indices_i, -logprob_i)
            else:
                nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)

            offset += logprob_i.size(0)

        return nll



## 实验结果

• 0
点赞
• 1
收藏
觉得还不错? 一键收藏
• 1
评论
06-11 661
01-31 987
09-27 406

1.余额是钱包充值的虚拟货币，按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载，可以购买VIP、付费专栏及课程。