【NLP学习记录】Embedding和EmbeddingBag

Embedding与EmbeddingBag详解

●🍨 本文为🔗365天深度学习训练营 中的学习记录博客
●🍖 原作者:K同学啊 | 接辅导、项目定制
●🚀 文章来源:K同学的学习圈子

1、Embedding详解

Embedding是Pytorch中最基本的词嵌入操作,TensorFlow中也有相同的函数,功能是一样的。Embedding是将每个离散的词汇映射到一个低维的连续向量空间中,并且保持了词汇直接的语义关系。在Pytorch中,Embedding的输入是一个整数张量,每个整数都代表着一个词汇的索引,输出是一个浮点型的张量,每个浮点数都代表着对应词汇的词嵌入向量。

嵌入层使用随机权重初始化,并将学习数据集中的所有词嵌入。它是一个灵活的层,可以以各种方式使用,如:

  • 作为深度学习模型的一部分,其中嵌入与模型本身一起被学习。
  • 用于加载训练好的词嵌入模型
    嵌入层被定义为网络的第一个隐藏层。
    函数原型:
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)

常用参数:

num_embeddings: #词汇表大小,最大整数index + 1
embedding_dim:  #词向量的维度

简单示例,用Embedding将两个句子转换为词嵌入向量

import torch
import torch.nn as nn

vocab_size = 12 #词汇表大小
embedding_dim  = 4 #嵌入向量的维度

#创建一个embedding层
embedding = nn.Embedding(vocab_size, embedding_dim)

#假设我们有一个包含两个单词索引的输入序列
input_sequence1 = torch.tensor([1,5,8], dtype = torch.long)
input_sequence2 = torch.tensor([2,4], dtype = torch.long)

#使用Embedding层将输入序列转换为词嵌入
embedded_sequence1 = embedding(input_sequence1)
embedded_sequence2 = embedding(input_sequence2)

print(embedded_sequence1)
print(embedded_sequence2)

输出:
在这里插入图片描述
上例中,我们定义了简单的词嵌入模型

  1. 将大小为12 的词汇表中的每个词映射到了一个4维的向量空间中。
  2. 输入了两个句子,分别是[1, 5, 8]和[2, 4],每个数字代表着词汇表中的一个词汇的索引。
  3. 将这两个句子通过Embedding转换为词嵌入向量,并输出结果。
  4. 结果中,每个句子中的每个词汇都被映射成了4维的向量。

2. EmbeddingBag详解

EmbeddingBag是在Embedding基础上进一步优化的工具。主要优化点在于:它可以直接处理不定长的句子,并可计算句子中所有词汇的词嵌入向量的均值或总和。前者可以简化使用,后者则可即时评估向量生成效果。
在Pytorch中,EmbeddingBag的输入是一个整数张量和一个偏移量张量,每个整数都代表着一个词汇的索引,偏移量则表示句子中每个词汇的位置,输出是一个浮点型的张量。每个浮点数都代表这对应句子的词嵌入向量的均值或总和。

示例:用EmbeddingBag将两个句子转换为词嵌入向量并计算它们的均值。

import torch
import torch.nn as nn

vocab_size = 12 #词汇表大小
embedding_dim = 4 #嵌入向量维度

#创建一个EmbeddingBag层
embedding_bag = nn.EmbeddingBag(vocab_size, embedding_dim, mode = 'mean')

#假设我们有两个不同长度的输入序列
input_sequence1 = torch.tensor([1, 5, 8], dtype = torch.long)
input_sequence2 = torch.tensor([2, 4], dtype = torch.long)

#将两个输入序列拼接在一起,并创建一个偏移张量
input_sequences = torch.cat([input_sequence1, input_sequence2])
offsets = torch.tensor([0,len(input_sequence1)], dtype = torch.long)

#使用EmbeddingBag层计算序列汇总(这里使用平均值)
embedded_bag = embedding_bag(input_sequences, offsets)

print(embedded_bag)

输出:
在这里插入图片描述
在该示例中,我们的模型构建步骤如下:

  1. 定义一个大小为12的词汇表,并将每个词汇映射到一个4维的向量空间中。
  2. 输入两个句子,分别为[1, 5, 8]和[2, 4],每个数字代表词汇表中的一个词汇的索引。
  3. 通过EmbeddingBag将每个句子中的每个词汇转换为词嵌入向量,并计算它们的均值。
  4. 结果表明,每个句子的词嵌入向量的均值都是一个4维的向量。

EmbeddingBag层中的mode参数用于指定如何对每个序列中的嵌入向量进行汇总。常用模式主要有3种:sum、mean、max。模式选择主要取决于具体的任务和数据集。

  • 文本分类任务中,通常使用mean模式,因为它可以捕捉到每个序列的平均嵌入,反映出序列的整体含义。
  • 序列标注任务中,通常使用sum模式,因为它可以捕捉到每个序列的所有信息,不会丢失任何关键信息。

3. 任务补充

任务要求:
加载
one-hot篇附件中的.txt文件,并使用EmbeddingBag和Embedding完成词嵌入

代码:

# ------文本处理,将句子转换为整数序列 ------
import torch
import jieba

# 确认打开的文本文件路径和文件名
file_name = "D:\Personal Data\Learning Data\DL Learning Data\Embedding.txt"

# 从文件中读取文本行,并替代预定义的Sentences
with open(file_name, "r", encoding = "utf-8") as file:
    Context     = file.read()
    Sentences   = Context.split()
    print("==== 文本分句: ====\n", Sentences)     # 打印核对结果

# 使用jieba.cut()函数逐句进行分词,结果输出为一个列表
tokenized_texts = [list(jieba.lcut(sentence)) for sentence in Sentences]
print("==== 分词结果: ====\n", tokenized_texts)   # 打印核对结果

# 构建词汇表
word_index = {}
index_word = {}
for i, word in enumerate(set([word for text in tokenized_texts for word in text])):
    word_index[word] = i
    index_word[i]    = word

print("==== 词汇表: ====\n", word_index)          # 打印核对结果

# 将文本转化为整数序列
sequences = [[word_index[word] for word in text] for text in tokenized_texts]
print("==== 文本序列: ====\n",sequences)          # 打印核对结果

# 获取词汇表大小, 并+1
vocab_size      = len(word_index) + 1

输出:

==== 文本分句: ====
 ['比较直观的编码方式是采用上面提到的字典序列。例如,对于一个有三个类别的问题,可以用1、2和3分别表示这三个类别。但是,这种编码方式存在一个问题,就是模型可能会错误地认为不同类别之间存在一些顺序或距离关系,而实际上这些关系可能是不存在的或者不具有实际意义的。', '为了避免这种问题,引入了one-hot编码(也称独热编码)。one-hot编码的基本思想是将每个类别映射到一个向量,其中只有一个元素的值为1,其余元素的值为0。这样,每个类别之间就是相互独立的,不存在顺序或距离关系。例如,对于三个类别的情况,可以使用如下的one-hot编码:']
Loading model cost 0.753 seconds.
Prefix dict has been built successfully.
==== 分词结果: ====
 [['比较', '直观', '的', '编码方式', '是', '采用', '上面', '提到', '的', '字典', '序列', '。', '例如', ',', '对于', '一个', '有', '三个', '类别', '的', '问题', ',', '可以', '用', '1', '、', '2', '和', '3', '分别', '表示', '这', '三个', '类别', '。', '但是', ',', '这种', '编码方式', '存在', '一个', '问题', ',', '就是', '模型', '可能', '会', '错误', '地', '认为', '不同', '类别', '之间', '存在', '一些', '顺序', '或', '距离', '关系', ',', '而', '实际上', '这些', '关系', '可能', '是', '不', '存在', '的', '或者', '不', '具有', '实际意义', '的', '。'], ['为了', '避免', '这种', '问题', ',', '引入', '了', 'one', '-', 'hot', '编码', '(', '也', '称', '独热', '编码', ')', '。', 'one', '-', 'hot', '编码', '的', '基本', '思想', '是', '将', '每个', '类别', '映射', '到', '一个', '向量', ',', '其中', '只有', '一个', '元素', '的', '值', '为', '1', ',', '其余', '元素', '的', '值', '为', '0', '。', '这样', ',', '每个', '类别', '之间', '就是', '相互', '独立', '的', ',', '不', '存在', '顺序', '或', '距离', '关系', '。', '例如', ',', '对于', '三个', '类别', '的', '情况', ',', '可以', '使用', '如下', '的', 'one', '-', 'hot', '编码', ':']]
==== 词汇表: ====
 {'具有': 0, '就是': 1, '元素': 2, '引入': 3, '如下': 4, '思想': 5, '但是': 6, '模型': 7, '0': 8, '、': 9, '对于': 10, '为了': 11, '独立': 12, '其余': 13, '(': 14, '提到': 15, '这样': 16, '三个': 17, '采用': 18, '其中': 19, '表示': 20, '使用': 21, '到': 22, '存在': 23, '和': 24, ',': 25, '一些': 26, '这种': 27, '有': 28, '向量': 29, '例如': 30, '字典': 31, '编码': 32, '或': 33, '会': 34, 'hot': 35, '映射': 36, '比较': 37, '3': 38, '可以': 39, '。': 40, '了': 41, '序列': 42, '将': 43, '情况': 44, '2': 45, '是': 46, '或者': 47, '上面': 48, '这': 49, '编码方式': 50, '用': 51, '避免': 52, '实际意义': 53, '直观': 54, ')': 55, '实际上': 56, '值': 57, '这些': 58, '-': 59, '分别': 60, '而': 61, '相互': 62, '之间': 63, '也': 64, '只有': 65, 'one': 66, '认为': 67, '1': 68, '距离': 69, '问题': 70, '一个': 71, '可能': 72, '独热': 73, '称': 74, '类别': 75, '的': 76, '顺序': 77, '基本': 78, '不同': 79, '关系': 80, ':': 81, '地': 82, '错误': 83, '每个': 84, '不': 85, '为': 86}
==== 文本序列: ====
 [[37, 54, 76, 50, 46, 18, 48, 15, 76, 31, 42, 40, 30, 25, 10, 71, 28, 17, 75, 76, 70, 25, 39, 51, 68, 9, 45, 24, 38, 60, 20, 49, 17, 75, 40, 6, 25, 27, 50, 23, 71, 70, 25, 1, 7, 72, 34, 83, 82, 67, 79, 75, 63, 23, 26, 77, 33, 69, 80, 25, 61, 56, 58, 80, 72, 46, 85, 23, 76, 47, 85, 0, 53, 76, 40], [11, 52, 27, 70, 25, 3, 41, 66, 59, 35, 32, 14, 64, 74, 73, 32, 55, 40, 66, 59, 35, 32, 76, 78, 5, 46, 43, 84, 75, 36, 22, 71, 29, 25, 19, 65, 71, 2, 76, 57, 86, 68, 25, 13, 2, 76, 57, 86, 8, 40, 16, 25, 84, 75, 63, 1, 62, 12, 76, 25, 85, 23, 77, 33, 69, 80, 40, 30, 25, 10, 17, 75, 76, 44, 25, 39, 21, 4, 76, 66, 59, 35, 32, 81]]

使用EmbeddingBag进行词嵌入:

# 创建一个EmbeddingBag层
embedding_dim   = 100               # 定义嵌入向量的维度
embedding_bag   = torch.nn.EmbeddingBag(vocab_size, embedding_dim, mode = "mean")

# 将多个输入序列拼接在一起,并创建一个偏移量张量
# 首先需要创建空张量和空列表
input   = torch.tensor([], dtype = torch.long)
offset  = []
# 逐句处理,进行张量拼接、向列表中添加偏移量数值
for sequence in sequences:
    offset.append(len(input))
    input   = torch.cat([input, torch.tensor(sequence, dtype= torch.long)])

# 将列表形式的偏移量转换为张量形式,用于embedding_bag()函数的输入
offset = torch.tensor(offset, dtype = torch.long)

# 检查序列张量拼接和索引生成结果
print("-"*80)    
print(offset)
print(input)
print("-"*80)

# 使用Embedding层将输入序列转换为词嵌入
embedded_bag = embedding_bag(input, offset)

# 打印输出结果
print("词嵌入结果: \n", embedded_bag)

输出:

tensor([ 0, 75])
tensor([37, 54, 76, 50, 46, 18, 48, 15, 76, 31, 42, 40, 30, 25, 10, 71, 28, 17,
        75, 76, 70, 25, 39, 51, 68,  9, 45, 24, 38, 60, 20, 49, 17, 75, 40,  6,
        25, 27, 50, 23, 71, 70, 25,  1,  7, 72, 34, 83, 82, 67, 79, 75, 63, 23,
        26, 77, 33, 69, 80, 25, 61, 56, 58, 80, 72, 46, 85, 23, 76, 47, 85,  0,
        53, 76, 40, 11, 52, 27, 70, 25,  3, 41, 66, 59, 35, 32, 14, 64, 74, 73,
        32, 55, 40, 66, 59, 35, 32, 76, 78,  5, 46, 43, 84, 75, 36, 22, 71, 29,
        25, 19, 65, 71,  2, 76, 57, 86, 68, 25, 13,  2, 76, 57, 86,  8, 40, 16,
        25, 84, 75, 63,  1, 62, 12, 76, 25, 85, 23, 77, 33, 69, 80, 40, 30, 25,
        10, 17, 75, 76, 44, 25, 39, 21,  4, 76, 66, 59, 35, 32, 81])
--------------------------------------------------------------------------------
词嵌入结果: 
 tensor([[ 5.5692e-02, -7.1316e-02, -6.6372e-02, -1.6054e-02, -1.5890e-02,
          9.3022e-03, -4.8811e-02, -4.7034e-02, -8.3843e-03, -1.9857e-02,
          1.1352e-02,  1.4714e-01,  1.2782e-01,  1.2540e-01,  2.2769e-01,
         -2.1690e-01,  1.2728e-02, -1.9718e-01,  4.8604e-02, -8.8129e-02,
          1.4818e-01, -3.2952e-01, -3.2805e-02, -1.6356e-01,  1.1112e-01,
         -1.0095e-01, -5.1578e-02, -1.1523e-01,  3.2936e-01, -3.6964e-01,
         -8.3445e-02, -6.9567e-02,  1.1665e-01,  1.6558e-01,  2.6067e-01,
         -1.4318e-01,  6.1249e-02, -3.4959e-02, -4.9525e-02,  5.4743e-02,
         -3.6878e-02,  6.7813e-02,  7.3439e-02, -6.0280e-03,  7.6804e-02,
          3.0789e-02,  1.6987e-01, -6.2407e-02,  8.7294e-02,  1.1892e-01,
         -1.7377e-01, -2.3485e-04,  4.0972e-02,  1.6278e-02,  1.0198e-01,
         -1.4946e-01, -2.4754e-01, -1.2399e-01,  4.3227e-02,  4.9916e-02,
         -2.0984e-01,  1.5504e-01, -1.7622e-01,  1.1868e-01,  1.7071e-01,
         -2.5039e-02,  1.0324e-01, -1.2662e-02,  2.5191e-01,  8.0460e-02,
          7.6614e-02, -2.7530e-01,  6.2472e-02,  1.6579e-01,  8.8133e-02,
          1.3551e-01, -2.7536e-01,  3.3397e-02, -1.5716e-01, -2.0973e-01,
         -1.2795e-01, -2.1313e-01, -2.1758e-02,  1.2416e-01,  1.7992e-01,
         -3.7501e-01, -4.4248e-02,  1.2105e-01, -3.8175e-02,  2.2732e-01,
         -6.5679e-02,  8.8733e-02, -1.3111e-01,  5.4078e-02,  4.8781e-03,
          2.1772e-01,  1.9874e-01,  2.0958e-03, -9.0118e-02, -2.5515e-01],
        [ 1.3323e-01,  3.4331e-02, -8.2839e-02,  1.2361e-01, -2.8169e-01,
         -4.4598e-02, -3.2946e-02, -1.4866e-01,  1.6180e-01, -2.7819e-01,
         -6.4126e-02, -4.7323e-02,  1.3915e-02,  1.6032e-01,  1.3364e-01,
         -7.4780e-02, -1.5139e-01, -3.3980e-01,  1.9030e-01,  1.9205e-02,
          2.0883e-01, -2.5076e-01, -5.6575e-02,  6.3742e-02,  1.1623e-01,
          4.9935e-03,  1.2527e-01, -4.2433e-01,  7.7243e-02, -1.4197e-01,
         -6.0625e-02, -1.9400e-01,  2.1380e-02,  9.1898e-02,  2.0745e-01,
         -3.2166e-01,  6.8514e-02, -1.3532e-02, -1.5256e-01,  5.6108e-02,
         -1.3043e-01,  4.6693e-02,  4.1241e-02,  2.1686e-01,  2.1685e-01,
          1.0330e-01,  1.9492e-01,  4.1955e-02,  1.1217e-01,  2.3970e-01,
         -8.2919e-02, -2.8210e-01,  1.9147e-01, -4.3560e-02, -3.4606e-01,
          6.1696e-02, -2.7092e-01, -2.3062e-01,  4.1380e-02,  1.0824e-01,
         -2.8622e-01,  3.2143e-03, -1.1270e-01,  4.5128e-02,  1.0355e-01,
          9.4424e-03, -1.0250e-02, -3.4805e-01,  2.8352e-01,  1.2028e-01,
         -4.0606e-02, -6.6366e-02,  8.5532e-02,  3.8679e-01,  2.7721e-01,
          2.1182e-01, -3.1888e-01, -5.2964e-02,  7.4347e-03, -4.4626e-02,
         -5.2423e-02, -2.4242e-01,  3.4623e-02, -1.0850e-01,  6.3444e-02,
         -1.3703e-01,  1.8642e-01,  1.6533e-01, -1.4030e-01,  1.9591e-01,
         -3.6116e-02,  2.3281e-01, -8.3522e-03,  1.1427e-01, -4.5963e-03,
          1.5077e-01, -2.7178e-02, -3.9486e-02, -1.0071e-01, -1.1353e-01]],
       grad_fn=<EmbeddingBagBackward0>)
  • 17
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值