attention机制学习

attention机制通过模仿人类注意力,解决LSTM/RNN长序列处理难题。在encoder-decoder框架中,它允许decoder根据输入序列不同部分赋予不同权重,提高处理效率。文章通过实例解释了如何使用attention进行新闻标题文本分类,包括数据处理、模型构建和训练测试过程。
摘要由CSDN通过智能技术生成

attention机制概述

attention机制是模仿人类注意力而提出的一种解决问题的办法,简单地说就是从大量信息中快速筛选出高价值信息。主要用于解决LSTM/RNN模型输入序列较长的时候很难获得最终合理的向量表示问题,做法是保留LSTM的中间结果,用新的模型对其进行学习,并将其与输出进行关联,从而达到信息筛选的目的。

encoder+decoder背景

encoder+decoder,中文名字是编码器和解码器,应用于seq2seq问题,其实就是固定长度的输入转化为固定长度输出。其中encoder和decoder可以采用的模型包括CNN/RNN/BiRNN/GRU/LSTM等,可以根据需要自己的喜好自由组合。
在这里插入图片描述
encoder过程将输入的句子转换为语义中间件,decoder过程根据语义中间件和之前的单词输出,依次输出最有可能的单词组成句子。
在这里插入图片描述
问题就是当输入长度非常长的时候,这个时候产生的语义中间件效果非常的不好,需要调整。

attention模型

attention模型用于解码过程中,它改变了传统decoder对每一个输入都赋予相同向量的缺点,而是根据单词的不同赋予不同的权重。在encoder过程中,输出不再是一个固定长度的中间语义,而是一个由不同长度向量构成的序列,decoder过程根据这个序列子集进行进一步处理。

在这里插入图片描述
假设输入为一句英文的话:Tom chase Jerry,最终的结果应该是逐步输出 “汤姆”,“追逐”,“杰瑞”。
如果用传统encoder-decoder模型,那么在翻译Jerry时,所有输入单词对翻译的影响都是相同的,但显然Jerry的贡献度应该更高。
引入attention后,每个单词都会有一个权重:(Tom,0.3)(Chase,0.2) (Jerry,0.5),现在的关键是权重应如何计算。
在这里插入图片描述
从图上可以看出来,加了attention机制以后,encoder层的每一步输出都会和当前的输出进行联立计算(wx+b形式),最后用softmx函数生成概率值。
概率值出来了,最后的结果就是一个加权和的形式。
在这里插入图片描述
基本上所有的attention都采用了这个原理,只不过算权重的函数形式可能会有所不同,但想法相同。

使用attention机制实现新闻标题文本分类

本次实验数据集来自 头条爬取的新闻标题数据集
训练数据集被处理为“类别-标题”的格式

读取数据并进行分词处理与生成词典:

from pyhanlp import HanLP
import numpy as np
from tqdm import tqdm

# 读取原始数据集分词预处理 并保存词典
def read_toutiao_dataset(data_path, save_vocab_path):
    with open(data_path, "r", encoding="utf8") as fo:
        all_lines = fo.readlines()
    datas, labels = [], []
    word_vocabs = {
   }
    for line in tqdm(all_lines):
        content_words = []
        category, content = line.strip().split("_!_")
        for term in HanLP.segment(content):
            if term.word not in word_vocabs:
                word_vocabs[term.word] = len(word_vocabs)+1
            content_words.append(term.word)
        datas.append(content_words)
        labels.append(category)
    with open(save_vocab_path, "w", encoding="utf8") as fw:
        for word, index in word_vocabs.items():
            fw.write(word+"\n")
    return datas, labels

读取词典,生成索引并把文本序列变成词编号序列:

def read_word_vocabs(save_vocab_path, special_words):
    with open(save_vocab_path, "r", encoding="utf8") as fo:
        word_vocabs = [word.strip() for word in fo]
    word_vocabs = special_words + word_vocabs
    idx2vocab = {
   idx: char for idx, char in enumerate(word_vocabs)} # 索引-词对应
    vocab2idx = {
   char: idx for idx, char in i
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值