Python 中图像标题生成的注意力机制实战

本文介绍了在Python中使用注意力机制生成图像标题的方法,特别是Bahdanau注意力,通过VGG16提取图像特征,LSTM解码器结合注意力模型生成描述。详细讲解了数据预处理、模型定义、训练过程及BLEU评估。
摘要由CSDN通过智能技术生成

总览

了解图像字幕生成的注意力机制 实现注意力机制以在python中生成字幕

介绍

注意机制是人类所具有的复杂的认知能力。当人们收到信息时,他们可以有意识地选择一些主要信息,而忽略其他次要信息。

这种自我选择的能力称为注意力。注意机制使神经网络能够专注于其输入子集以选择特定特征。

近年来,神经网络推动了图像字幕的巨大发展。研究人员正在为计算机视觉和序列到序列建模系统寻找更具挑战性的应用程序。他们试图用人类的术语描述世界。之前我们看到了通过Merge架构进行图像标题处理的过程,今天,我们将探讨一种更为复杂而精致的设计来解决此问题。

注意机制已成为深度学习社区中从业者的首选方法。它最初是在使用Seq2Seq模型的神经机器翻译的背景下设计的,但今天我们将看看它在图像字幕中的实现。

注意机制不是将整个图像压缩为静态表示,而是使显着特征在需要时动态地走在最前列。当图像中有很多杂波时,这一点尤其重要。

让我们举个例子来更好地理解:

我们的目标是生成一个标题,例如“两只白狗在雪地上奔跑”。为此,我们将看到如何实现一种称为Bahdanau的注意力或本地注意力的特定类型的注意力机制。

通过这种方式,我们可以看到模型在生成标题时将焦点放在图像的哪些部分。此实现将需要深度学习的强大背景。

目录

1、问题陈述的处理

2、了解数据集

3、实现

3.1、导入所需的库

3.2、数据加载和预处理

3.3、模型定义

3.4、模型训练

3.5、贪婪搜索和BLEU评估

4、下一步是什么?

5、尾注

问题陈述的处理

编码器-解码器图像字幕系统将使用将产生隐藏状态的预训练卷积神经网络对图像进行编码。然后,它将使用LSTM解码此隐藏状态并生成标题。

对于每个序列元素,将先前元素的输出与新序列数据结合起来用作输入。这为RNN网络提供了一种记忆,可能使字幕更具信息性和上下文感知能力。

但是RNN的训练和评估在计算上往往很昂贵,因此在实践中,内存只限于少数几个元素。注意模型可以通过从输入图像中选择最相关的元素来帮助解决此问题。使用Attention机制,首先将图像分为n个部分,然后我们计算每个图像的图像表示形式。当RNN生成新单词时,注意机制将注意力集中在图像的相关部分上,因此解码器仅使用特定的图片的一部分。

在Bahdanau或本地关注中,关注仅放在少数几个来源位置。由于全球关注集中于所有目标词的所有来源方词,因此在计算上非常昂贵。为了克服这种缺陷,本地注意力选择只关注每个目标词的编码器隐藏状态的一小部分。

局部注意力首先找到对齐位置,然后在其位置所在的左右窗口中计算注意力权重,最后对上下文向量进行加权。局部注意的主要优点是减少了注意机制计算的成本。

在计算中,本地注意力不是考虑源语言端的所有单词,而是根据预测函数预测在当前解码时要对齐的源语言端的位置,然后在上下文窗口中导航, 仅考虑窗口中的单词。

Bahdanau注意的设计

编码器和解码器的所有隐藏状态用于生成上下文向量。注意机制将输入和输出序列与前馈网络参数化的比对得分进行比对。它有助于注意源序列中最相关的信息。该模型基于与源位置和先前生成的目标词关联的上下文向量来预测目标词。

为了参考原始字幕评估字幕,我们使用一种称为BLEU的评估方法。它是使用最广泛的评估指标。它用于分析要评估的翻译语句与参考翻译语句之间n-gram的相关性。

在本文中,多个图像等效于翻译中的多个源语言句子。BLEU的优点是考虑更长的匹配信息,它认为的粒度是n元语法字而不是单词。BLEU的缺点是无论匹配哪种n-gram,都将被视为相同。

我希望这使您对我们如何处理此问题陈述有所了解。让我们深入研究实施!

了解数据集

我使用了Flickr8k数据集,其中每个图像都与五个不同的标题相关联,这些标题描述了所收集的图像中描述的实体和事件。

Flickr8k体积小巧,可以使用CPU在低端笔记本电脑/台式机上轻松进行培训,因此是一个很好的入门数据集。

我们的数据集结构如下:

让我们实现字幕生成的注意力机制!

步骤1:导入所需的库

在这里,我们将利用Tensorflow创建模型并对其进行训练。大部分代码归功于TensorFlow教程。如果您想要GPU进行训练,则可以使用Google Colab或Kaggle笔记本。

import string
import numpy as np
import pandas as pd
from numpy import array
from pickle import load
 
from PIL import Image
import pickle
from collections import Counter
import matplotlib.pyplot as plt
 
import sys, time, os, warnings
warnings.filterwarnings("ignore")
import re
 
import keras
import tensorflow as tf
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu
 
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense, BatchNormalization
from keras.layers import LSTM
from keras.layers import Embedding
from keras.layers import Dropout
from keras.layers.merge import add
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import load_img, img_to_array
from keras.preprocessing.text import Tokenizer
from keras.applications.vgg16 import VGG16, preprocess_input
 
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

步骤2:数据加载和预处理

定义图像和字幕路径,并检查数据集中总共有多少图像。

image_path = "/content/gdrive/My Drive/FLICKR8K/Flicker8k_Dataset"
dir_Flickr_text = "/content/gdrive/My Drive/FLICKR8K/Flickr8k_text/Flickr8k.token.txt"
jpgs = os.listdir(image_path)
 
print("Total Images in Dataset = {}".format(len(jpgs)))

输出如下:

我们创建一个数据框来存储图像ID和标题,以便于使用。

file = open(dir_Flickr_text,'r')
text = file.read()
file.close()
 
datatxt = []
for line in text.split('\n'):
   col = line.split('\t')
   if len(col) == 1:
       continue
   w = col[0].split("#")
   datatxt.append(w + [col[1].lower()])
 
data = pd.DataFrame(datatxt,columns=["filename","index","caption"])
data = data.reindex(columns =['index','filename','caption'])
data = data[data.filename != '2258277193_586949ec62.jpg.1']
uni_filenames = np.unique(data.filename.values)
 
data.head()

输出如下:

接下来,让我们可视化一些图片及其5个标题:

npic = 5
npix = 224
target_size = (npix,npix,3)
count = 1
 
fig = plt.figure(figsize=(10,20))
for jpgfnm in uni_filenames[10:14]:
   filename = image_path + '/' + jpgfnm
   captions = list(data["caption"].loc[data["filename"]==jpgfnm].values)
   image_load = load_img(filename, target_size=target_size)
   ax = fig.add_subplot(npic,2,count,xticks=[],yticks=[])
   ax.imshow(image_load)
   count += 1
 
   ax = fig.add_subplot(npic,2,count)
   plt.axis('off')
   ax
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值