RNN模型实现(二)

训练模型

训练模型有一套固定的流程,其中的细节只需要略微调整即可。

定义遍历数据集训练一次函数:

# 训练网络模型
def train_epoch(net, train_iter, loss, updater, device):
    # 每一轮训练函数
    # net:网络模型
    # train_iter:训练迭代器
    # loss:损失
    # updater:优化器
    # device:设备编号
    state = None
    loss_, numel_ = 0, 0
    for X,y in train_iter:
        if state == None:
            state = net.begin_state(batch_size=X.shape[0], device=device)
        else:
            for s in state:
                s.detach_()
        y = y.T.reshape(-1) # 将转置后的向量铺平成一维向量
        X,y = X.to(device),y.to(device)
        y_hat,state = net(X, state)
        l = loss(y_hat, y).mean()
        updater.zero_grad()
        l.backward()
        updater.step()
        loss_ += l*y.numel()
        numel_ += y.numel()
    return math.exp(loss_/numel_)

这里的损失选择交叉熵损失,函数的返回值为困惑度,值越接近1,模型效果越好,接下来定义训练函数:

def train(net, train_iter, updater, lr, num_epochs, vocab, device):
    loss = nn.CrossEntropyLoss()
    updater = torch.optim.SGD(net.params, lr=lr)
    for epoch in range(num_epochs):
        ppl = train_epoch(net, train_iter, loss, updater, device)
        if (epoch + 1) % 100 == 0:
            print(f'epoch: {epoch+1}  困惑度:{ppl:.2f}')
    print(f'训练完成,困惑度:{ppl:.2f}')
    print(predict_char('hello', 50, net, vocab, device))

写一个测试的例子:

loss = nn.CrossEntropyLoss()
lr = 1.2
updater = torch.optim.SGD(net.params, lr)
num_epochs = 500
train_epoch(net, train_iter, loss, updater, device)
train(net, train_iter, updater, lr, num_epochs, vocab, device)

输出的结果部分如下:

在这里插入图片描述

可以看到在训练迭代次数在120-130之间时,困惑度出现了小幅度的上升,而在随后出现了大幅度的上升,这种现象被称为“梯度爆炸”。为了缓解这种情况,我们引入了梯度裁剪进行处理:

# 对于训练过程中的梯度爆炸问题,使用梯度裁剪的方式缓解这种情况
def grad_clipping(net, theta):
    # 遍历存储梯度信息的参数
    params = net.params
    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm

将梯度剪裁引入训练过程中后,输出的结果如下:

epoch: 100  困惑度:7.14
epoch: 200  困惑度:2.63
epoch: 300  困惑度:1.28
epoch: 400  困惑度:1.22
epoch: 500  困惑度:1.20
训练完成,困惑度:1.20
hellor ary time tove gotings onsthod and anetareoushing

可以看到训练完成后的困惑度达到了1.20,预测的结果虽然难以理解,但表现出一定的预测潜力。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值