PyTorch深度学习(2) -- RNN及其进阶模型实现
0 预测训练函数
def train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,
vocab_size, device, corpus_indices, idx_to_char,
char_to_idx, is_random_iter, num_epochs, num_steps,
lr, clipping_theta, batch_size, pred_period,
pred_len, prefixes):
if is_random_iter:
data_iter_fn = d2l.data_iter_random
else:
data_iter_fn = d2l.data_iter_consecutive
params = get_params()
loss = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
if not is_random_iter:
state = init_rnn_state(batch_size, num_hiddens, device)
l_sum, n, start = 0.0, 0, time.time()
data_iter = data_iter_fn(corpus_indices, batch_size, num_steps, device)
for X, Y in data_iter:
if is_random_iter:
state = init_rnn_state(batch_size, num_hiddens, device)
else:
for s in state:
s.detach_()
inputs = to_onehot(X, vocab_size)
(outputs, state) = rnn(inputs, state, params)
outputs = torch.cat(outputs, dim=0)
y = torch.transpose(Y, 0, 1).contiguous().view(-1)
l = loss(outputs, y.long())
if params[0].grad is not None:
for param