用 TensorFlow 在 Transformers 上生成字幕的注意机制的实现

本文介绍了如何使用Tensorflow实现Transformer模型,特别是针对图像字幕任务的注意力机制。通过对比Transformer与传统的注意力模型,展示了Transformer在序列到序列建模中的优势,并提供了详细的实现步骤,包括数据预处理、模型定义、训练和评估。
摘要由CSDN通过智能技术生成

总览

  • 了解最先进的变压器模型。

  • 了解我们如何使用Tensorflow在已经看到的图像字幕问题上实现变形金刚

  • 比较《变形金刚》与注意力模型的结果。

介绍

我们已经看到,注意力机制已成为各种任务(例如图像字幕)中引人注目的序列建模和转导模型的组成部分,从而允许对依赖项进行建模,而无需考虑它们在输入或输出序列中的距离。

Transformer是一种避免重复发生的模型体系结构,而是完全依赖于注意力机制来绘制输入和输出之间的全局依存关系。Transformer体系结构允许更多并行化,并可以达到翻译质量方面的最新水平。

在本文中,让我们看看如何使用TensorFlow来实现用变形金刚生成字幕的注意力机制。

开始之前的先决条件:

  • Python编程

  • Tensorflow和Keras

  • RNN和LSTM

  • 转移学习

  • 编码器和解码器架构

  • 深度学习的要点–注意序列到序列建模

我建议您在阅读本文前可以参考下面资料:

一个动手教程来学习Python中图像标题生成的注意机制

https://www.analyticsvidhya.com/blog/2020/11/attention-mechanism-for-caption-generation/

目录 

一、Transformer 架构 

二、使用Tensorflow的变压器字幕生成注意机制的实现 

2.1、导入所需的库 

2.2、数据加载和预处理 

2.3、模型定义 

2.4、位置编码 

2.5、多头注意力 

2.6、编码器-解码器层 

2.7、Transformer 

2.8、模型超参数 

2.9、模型训练 

2.10、BLEU评估 

2.11、比较方式 

三、下一步是什么?

四、尾注

Transformer 架构

Transformer 网络采用类似于RNN的编解码器架构。主要区别在于,转换器可以并行接收输入的句子/顺序,即没有与输入相关的时间步长,并且句子中的所有单词都可以同时传递。

让我们从了解变压器的输入开始。

考虑一下英语到德语的翻译。我们将整个英语句子输入到输入嵌入中。可以将输入嵌入层视为空间中的一个点,其中含义相似的单词在物理上彼此更接近,即,每个单词映射到具有连续值的矢量来表示该单词。

现在的问题是,不同句子中的相同单词可能具有不同的含义,这就是位置编码输入的地方。由于转换器不包含递归和卷积,因此为了使模型能够利用序列的顺序,它必须利用一些有关序列中单词相对或绝对位置的信息。这个想法是使用固定或学习的权重,该权重对与句子中标记的特定位置有关的信息进行编码。

类似地,将目标德语单词输入到输出嵌入中,并将其位置编码矢量传递到解码器块中。

编码器块具有两个子层。第一个是多头自我关注机制,第二个是简单的位置完全连接的前馈网络。对于每个单词,我们可以生成一个注意力向量,该向量捕获句子中单词之间的上下文关系。编码器中的多头注意力会应用一种称为自我注意力的特定注意力机制。自注意力允许模型将输入中的每个单词与其他单词相关联。

除了每个编码器层中的两个子层之外,解码器还插入第三子层,该第三子层对编码器堆栈的输出执行多头关注。与编码器类似,我们在每个子层周围采用残余连接,然后进行层归一化。来自编码器的德语单词的注意力向量和英语句子的注意力向量被传递到第二多头注意力。

该注意块将确定每个单词向量彼此之间的关联程度。这是英语到德语单词映射的地方。解码器以充当分类器的线性层和softmax来封闭,以获取单词概率。

现在,您已基本了解了转换器的工作方式,让我们看看如何使用Tensorflow将其实现用于图像字幕任务,并将我们的结果与其他方法进行比较。

使用TensorFlow在Transformers 上生成字幕的注意机制的实现

步骤1:导入所需的库

在这里,我们将利用Tensorflow创建模型并对其进行训练。大部分代码归功于TensorFlow教程。如果您想要GPU进行训练,则可以使用Google Colab或Kaggle笔记本。

import string
import numpy as np
import pandas as pd
from numpy import array
from PIL import Image
import pickle
 
import matplotlib.pyplot as plt
import sys, time, os, warnings
warnings.filterwarnings("ignore")
import re
 
import keras
import tensorflow as tf
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu
 
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense, BatchNormalization
from keras.layers import LSTM
from keras.layers import Embedding
from keras.layers import Dropout
from keras.layers.merge import add
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import load_img, img_to_array
from keras.preprocessing.text import Tokenizer
 
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

步骤2:数据加载和预处理

定义图像和字幕路径,并检查数据集中总共有多少图像。

image_path = "/content/gdrive/My Drive/FLICKR8K/Flicker8k_Dataset"
dir_Flickr_text = "/content/gdrive/My Drive/FLICKR8K/Flickr8k_text/Flickr8k.token.txt"
jpgs = os.listdir(image_path)
 
print("Total Images in Dataset = {}".format(len(jpgs)))

输出如下:

我们创建一个数据框来存储图像ID和标题,以便于使用。

file = open(dir_Flickr_text,'r')
text = file.read()
file.close()
 
datatxt = []
for line in text.split('\n'):
   col = line.split('\t')
   if len(col) == 1:
       continue
   w = col[0].split("#")
   datatxt.append(w + [col[1].lower()])
 
data = pd.DataFrame(datatxt,columns=["filename","index","caption"])
data = data.reindex(columns =['index','filename','caption'])
data = data[data.filename != '2258277193_586949ec62.jpg.1']
uni_filenames = np.unique(data.filename.values)
 
data.head()

输出如下:

接下来,让我们可视化一些图片及其5个标题:

npic = 5
npix = 224
target_size = (npix,npix,3)
count = 1
 
fig = plt.figure(figsize=(10,20))
for jpgfnm in uni_filenames[10:14]:
   filename = image_path + '/' + jpgfnm
   captions = list(data["caption"].loc[data["filename"]==jpgfnm].values)
   image_load = load_img(filename, target_size=target_size)
   ax = fig.add_subplot(npic,2,count,xticks=[],yticks=[])
   ax.imshow(image_load)
   count += 1
 
   ax = fig.add_subplot(npic,2,count)
   plt.axis('off')
   ax.plot()
   ax.set_xlim(0,1)
   ax.set_ylim(0,len(captions))
   for i, caption in enumerate(captions):
       ax.text(0,i,caption,fontsize=20)
   count += 1
plt.show()

输出如下:

接下来,让我们看看我们当前的词汇量是多少:

vocabulary = []
for txt in data.caption.values:
   vocabulary.extend(txt.split())
print('Vocabulary Size: %d' % len(set(vocabulary)))

输出如下:

接下来执行一些文本清理,例如删除标点符号,单个字符和数字值:

def remove_punctuation(text_original):
   text_no_punctuation = text_original.translate(string.punctuat
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值