【深度之眼cs231n第七期】笔记(二十五)

本文深入分析了LSTM如何解决RNN的梯度消失问题,并讨论了一个过拟合的LSTM模型,该模型在验证集上生成的图片描述与实际不符。通过BLEU指标评估模型性能,提出达到0.3以上的BLEU unigram分数即可视为有效。同时,文中还介绍了sigmoid函数的应用,以及LSTM的单步和多步前向、反向传播计算。最后提到了CaptioningSolver的初始化过程。
摘要由CSDN通过智能技术生成

LSTM_Captioning.ipynb

原始RNN中求取梯度时,要多次乘以同一矩阵,这就容易导致梯度消失或梯度爆炸,而长短期记忆网络(LSTM)能解决这个问题。

过拟合的LSTM模型

这个过拟合的LSTM模型应该会得到小于0.5的损失。

    np.random.seed(231)
    # 每个epoch有50个数据
    small_data = load_coco_data(max_train=50)
    
    small_lstm_model = CaptioningRNN(
              cell_type='lstm',
              word_to_idx=data['word_to_idx'],
              input_dim=data['train_features'].shape[1],
              hidden_dim=512,
              wordvec_dim=256,
              dtype=np.float32,
            )
    # 每个epoch有50个数据,batch_size是25,所以每个epoch迭代2次
    # 乘以50个epoch,总共迭代100次
    small_lstm_solver = CaptioningSolver(small_lstm_model, small_data,
               update_rule='adam',
               num_epochs=50,
               batch_size=25,
               optim_config={
   
                 'learning_rate': 5e-3,
               },
               lr_decay=0.995,
               verbose=True, print_every=10,
             )
    small_lstm_solver.train()
    
    # 画出训练损失
    plt.plot(small_lstm_solver.loss_history)
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.title('Training loss history')
    plt.show()

在这里插入图片描述
使用训练好的模型生成图片描述,并和真实描述对比。

for split in ['train', 'val']:
    # 在训练集和验证集里分别选两张图片
    minibatch = sample_coco_minibatch(small_data, split=split, batch_size=2)
    gt_captions, features, urls = minibatch
    # 把真实描述从数字ID转换成单词
    gt_captions = decode_captions(gt_captions, data['idx_to_word'])
    # 使用训练好的模型生成图片描述
    sample_captions = small_lstm_model.sample(features)
    sample_captions = decode_captions(sample_captions, data['idx_to_word'])
    
    # 展示图片和描述(生成的和真实的)
    for gt_caption, sample_caption, url in zip(gt_captions, sample_captions, urls):
        plt.figure(figsize=(5, 0.5))
        # 本来应该直接展示图片的"plt.imshow(image_from_url(url))",但是展示不出来,只能打印链接
        print(url)
        plt.title('%s\n%s\nGT:%s' % (split, sample_caption, gt_caption))
        plt.axis('off')
        plt.show()

在这里插入图片描述
验证集图片生成的描述完全和图片不相关……
在这里插入图片描述
使用BLEU评价指标来评价模型(可以看看这篇文章),该评价指标在0到1之间,越接近1代表结果越好。

import nltk
def BLEU_score(gt_caption, sample_caption):
    """
    gt_caption: 真实的描述
    sample_caption: 生成的描述
    返回unigram BLEU得分
    """
    # 把描述分成一个个单词,开始、结束、未知token不用于计算得分
    reference = [x for x in gt_caption.split(' ') 
                 if ('<END>' not in x and '<START>' not in x and '<UNK>' not in x)]
    hypothesis = [x for x in sample_caption.split(' ') 
                  if ('<END>' not in x and '<START>' not in x and '<UNK>' not in x)
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值