[Hands-On NLP for Beginners] Implementing Word Embeddings (Embedding)

Word Embeddings

Word embedding is a technique used in natural language processing (NLP) to represent words as numbers so that computers can work with them. Put plainly, it is a way of turning text into numerical input for a computer.

Embedding and EmbeddingBag are the tools PyTorch provides for word embeddings in text data: they map discrete tokens into a low-dimensional continuous vector space, so that semantic relationships between words can be reflected in that space.
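
To make the difference between the two concrete, here is a minimal sketch (the sizes 10 and 3 are arbitrary and chosen only for illustration): nn.Embedding returns one vector per input index, while nn.EmbeddingBag pools the vectors of each bag of indices into a single vector.

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=3)                  # 10 tokens, 3-dim vectors
bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=3, mode='mean')

ids = torch.tensor([1, 4, 7], dtype=torch.long)
print(emb(ids).shape)               # torch.Size([3, 3]): one vector per index
print(bag(ids.unsqueeze(0)).shape)  # torch.Size([1, 3]): one pooled vector per bag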

Understanding Word Embeddings via RGB Colors

The RGB representation of color is in fact a typical distributed representation:

A color can be decomposed into three feature dimensions, and in principle any color can be expressed as a combination of those three. In the same way, a word can be decomposed into a chosen number of feature dimensions, and every word in the vocabulary can then be represented by a vector over those dimensions. That is what word embedding means.
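
For instance, under RGB pure red is (255, 0, 0) and orange is roughly (255, 165, 0); a word vector works the same way, except that we choose the number of dimensions ourselves. A tiny illustration (the word vectors below are made up purely for the analogy):

# A color is a point in a 3-dimensional space whose axes have fixed physical meaning (R, G, B).
red    = [255, 0, 0]
orange = [255, 165, 0]

# A word is a point in an n-dimensional space whose axes have no predefined meaning;
# these 4-dimensional values are invented here only to illustrate the idea.
king  = [0.8, 0.1, -0.3, 0.5]
queen = [0.7, 0.2, -0.2, 0.6]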

Of course, words differ from colors in an important way. The three dimensions of a color have a clear physical meaning (R, G and B), so the RGB value of any color can be obtained directly from physics. For words, we can neither give each dimension an interpretable meaning nor compute a word's vector directly. We therefore have to learn word vectors from a corpus with a model: the embedding matrix is treated as part of the model's parameters and is optimized through the co-occurrence or context relationships between words, and the matrix obtained at the end holds the word vectors of every word in the vocabulary.

One point worth spelling out, because it trips up some beginners, is where the initial word vectors come from. The answer is simply that they are filled in at random. The embedding matrix starts out randomly initialized just like the other model parameters; the forward pass computes the loss, backpropagation computes the gradient of every entry in the matrix, and gradient descent updates them, exactly as in ordinary model training. Once training has more or less converged, the embedding matrix is a reasonably accurate word-vector matrix.
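
As a minimal sketch of that training loop (this is not the code used later in this post; the CBOW-style objective, the toy sizes, and the single training pair are all made up purely to show that the embedding matrix is updated like any other parameter):

import torch
import torch.nn as nn

vocab_size, embedding_dim = 20, 4
embedding = nn.Embedding(vocab_size, embedding_dim)    # randomly initialized word vectors
classifier = nn.Linear(embedding_dim, vocab_size)      # predicts the center word from the context
optimizer = torch.optim.SGD(
    list(embedding.parameters()) + list(classifier.parameters()), lr=0.1)

# One made-up training pair: context word indices -> center word index.
context = torch.tensor([[2, 5, 7, 9]], dtype=torch.long)
center = torch.tensor([3], dtype=torch.long)

for _ in range(10):
    optimizer.zero_grad()
    hidden = embedding(context).mean(dim=1)            # average the context word vectors
    loss = nn.functional.cross_entropy(classifier(hidden), center)
    loss.backward()                                    # gradients also reach embedding.weight
    optimizer.step()                                   # ...so the word vectors themselves are updated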

Code

import torch
import torch.nn as nn
import torch.nn.functional as F
import jieba

# Read the raw text; the sentences in this file are separated by spaces
with open('./data/任务文件.txt', "r", encoding="utf-8") as file:
    content = file.read()
    texts = content.split(' ')

# Tokenize each sentence with jieba
tokenized_texts = [list(jieba.cut(text)) for text in texts]

print(tokenized_texts)

[['比较', '直观', '的', '编码方式', '是', '采用', '上面', '提到', '的', '字典', '序列', '。'], ['例如', ',', '对于', '一个', '有', '三个', '类别', '的', '问题', ',', '可以', '用', '1', '、', '2', '和', '3', '分别', '表示', '这', '三个', '类别', '。'], ['但是', ',', '这种', '编码方式', '存在', '一个', '问题', ',', '就是', '模型', '可能', '会', '错误', '地', '认为', '不同', '类别', '之间', '存在', '一些', '顺序', '或', '距离', '关系', ',', '而', '实际上', '这些', '关系', '可能', '是', '不', '存在', '的', '或者', '不', '具有', '实际意义', '的', '。'], ['为了', '避免', '这种', '问题', ',', '引入', '了', 'one', '-', 'hot', '编码', '(', '也', '称', '独热', '编码', ')', '。'], ['one', '-', 'hot', '编码', '的', '基本', '思想', '是', '将', '每个', '类别', '映射', '到', '一个', '向量', ',', '其中', '只有', '一个', '元素', '的', '值', '为', '1', ',', '其余', '元素', '的', '值', '为', '0', '。'], ['这样', ',', '每个', '类别', '之间', '就是', '相互', '独立', '的', ',', '不', '存在', '顺序', '或', '距离', '关系', '。']]

# Vocabulary. Note that " ".join(texts) is a single string and iterating over a string
# yields characters, so this builds a character-level vocabulary; the jieba tokens above
# are only printed for demonstration and are not used below.
word_index = {}
index_word = {}

for i, word in enumerate(set(" ".join(texts))):
    word_index[word] = i
    index_word[i] = word

print(word_index)
print(index_word)


# Convert each text into a sequence of character indices
sequences = [[word_index[word] for word in text] for text in texts]
print(sequences)


word_index: {')': 0, '例': 1, '。': 2, '于': 3, '到': 4, '地': 5, '错': 6, '避': 7, '列': 8, '引': 9, '每': 10, 'h': 11, '相': 12, '为': 13, '序': 14, '间': 15, '顺': 16, '式': 17, '免': 18, '提': 19, '模': 20, '比': 21, '表': 22, '也': 23, '入': 24, '离': 25, '字': 26, '映': 27, '的': 28, '互': 29, '中': 30, 'o': 31, '2': 32, '同': 33, '本': 34, '方': 35, '类': 36, '就': 37, '系': 38, '余': 39, '者': 40, '3': 41, '一': 42, '问': 43, 'e': 44, '以': 45, '距': 46, '采': 47, '元': 48, '-': 49, '射': 50, '实': 51, '热': 52, '误': 53, '、': 54, '观': 55, 'n': 56, '样': 57, '上': 58, '示': 59, '1': 60, '(': 61, '用': 62, '际': 63, '能': 64, '但': 65, '有': 66, '想': 67, '认': 68, '关': 69, '其': 70, '是': 71, '会': 72, '直': 73, '存': 74, '个': 75, '这': 76, '不': 77, '如': 78, '只': 79, '可': 80, '量': 81, 't': 82, '些': 83, '0': 84, '意': 85, '素': 86, '三': 87, '之': 88, '种': 89, '典': 90, '具': 91, '立': 92, ',': 93, '或': 94, '在': 95, '分': 96, '独': 97, '称': 98, '面': 99, '了': 100, '基': 101, '值': 102, '型': 103, '编': 104, '别': 105, '将': 106, '题': 107, '向': 108, '对': 109, '和': 110, '码': 111, '较': 112, '而': 113, '思': 114, ' ': 115, '义': 116}
index_word: {0: ')', 1: '例', 2: '。', 3: '于', 4: '到', 5: '地', 6: '错', 7: '避', 8: '列', 9: '引', 10: '每', 11: 'h', 12: '相', 13: '为', 14: '序', 15: '间', 16: '顺', 17: '式', 18: '免', 19: '提', 20: '模', 21: '比', 22: '表', 23: '也', 24: '入', 25: '离', 26: '字', 27: '映', 28: '的', 29: '互', 30: '中', 31: 'o', 32: '2', 33: '同', 34: '本', 35: '方', 36: '类', 37: '就', 38: '系', 39: '余', 40: '者', 41: '3', 42: '一', 43: '问', 44: 'e', 45: '以', 46: '距', 47: '采', 48: '元', 49: '-', 50: '射', 51: '实', 52: '热', 53: '误', 54: '、', 55: '观', 56: 'n', 57: '样', 58: '上', 59: '示', 60: '1', 61: '(', 62: '用', 63: '际', 64: '能', 65: '但', 66: '有', 67: '想', 68: '认', 69: '关', 70: '其', 71: '是', 72: '会', 73: '直', 74: '存', 75: '个', 76: '这', 77: '不', 78: '如', 79: '只', 80: '可', 81: '量', 82: 't', 83: '些', 84: '0', 85: '意', 86: '素', 87: '三', 88: '之', 89: '种', 90: '典', 91: '具', 92: '立', 93: ',', 94: '或', 95: '在', 96: '分', 97: '独', 98: '称', 99: '面', 100: '了', 101: '基', 102: '值', 103: '型', 104: '编', 105: '别', 106: '将', 107: '题', 108: '向', 109: '对', 110: '和', 111: '码', 112: '较', 113: '而', 114: '思', 115: ' ', 116: '义'}
sequences: [[21, 112, 73, 55, 28, 104, 111, 35, 17, 71, 47, 62, 58, 99, 19, 4, 28, 26, 90, 14, 8, 2], [1, 78, 93, 109, 3, 42, 75, 66, 87, 75, 36, 105, 28, 43, 107, 93, 80, 45, 62, 60, 54, 32, 110, 41, 96, 105, 22, 59, 76, 87, 75, 36, 105, 2], [65, 71, 93, 76, 89, 104, 111, 35, 17, 74, 95, 42, 75, 43, 107, 93, 37, 71, 20, 103, 80, 64, 72, 6, 53, 5, 68, 13, 77, 33, 36, 105, 88, 15, 74, 95, 42, 83, 16, 14, 94, 46, 25, 69, 38, 93, 113, 51, 63, 58, 76, 83, 69, 38, 80, 64, 71, 77, 74, 95, 28, 94, 40, 77, 91, 66, 51, 63, 85, 116, 28, 2], [13, 100, 7, 18, 76, 89, 43, 107, 93, 9, 24, 100, 31, 56, 44, 49, 11, 31, 82, 104, 111, 61, 23, 98, 97, 52, 104, 111, 0, 2], [31, 56, 44, 49, 11, 31, 82, 104, 111, 28, 101, 34, 114, 67, 71, 106, 10, 75, 36, 105, 27, 50, 4, 42, 75, 108, 81, 93, 70, 30, 79, 66, 42, 75, 48, 86, 28, 102, 13, 60, 93, 70, 39, 48, 86, 28, 102, 13, 84, 2], [76, 57, 93, 10, 75, 36, 105, 88, 15, 37, 71, 12, 29, 97, 92, 28, 93, 77, 74, 95, 16, 14, 94, 46, 25, 69, 38, 2]]
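
As printed above, the vocabulary is character-level (index 115 is even the space used for joining). If you instead wanted to embed the jieba tokens, a minimal sketch would look like the following (the names word_vocab, word_index_w and sequences_w are hypothetical and not used anywhere else in this post):

# Hypothetical word-level variant: index the jieba tokens instead of single characters.
word_vocab = sorted({w for sent in tokenized_texts for w in sent})
word_index_w = {w: i for i, w in enumerate(word_vocab)}
sequences_w = [[word_index_w[w] for w in sent] for sent in tokenized_texts]
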
vocab_size = len(word_index)  # vocabulary size
embedding_dim = 4  # dimensionality of each embedding vector

# Create an Embedding layer (a randomly initialized vocab_size x embedding_dim weight matrix)
embedding = nn.Embedding(vocab_size, embedding_dim)

# Wrap each index sequence in a LongTensor
input_sequences = [torch.tensor(sequence, dtype=torch.long) for sequence in sequences]
print(input_sequences)

# Look up the embedding vector for every index in each sequence
embedding_sequences = [embedding(sequence) for sequence in input_sequences]
print(embedding_sequences)
[tensor([ 21, 112,  73,  55,  28, 104, 111,  35,  17,  71,  47,  62,  58,  99,
         19,   4,  28,  26,  90,  14,   8,   2]), tensor([  1,  78,  93, 109,   3,  42,  75,  66,  87,  75,  36, 105,  28,  43,
        107,  93,  80,  45,  62,  60,  54,  32, 110,  41,  96, 105,  22,  59,
         76,  87,  75,  36, 105,   2]), tensor([ 65,  71,  93,  76,  89, 104, 111,  35,  17,  74,  95,  42,  75,  43,
        107,  93,  37,  71,  20, 103,  80,  64,  72,   6,  53,   5,  68,  13,
         77,  33,  36, 105,  88,  15,  74,  95,  42,  83,  16,  14,  94,  46,
         25,  69,  38,  93, 113,  51,  63,  58,  76,  83,  69,  38,  80,  64,
         71,  77,  74,  95,  28,  94,  40,  77,  91,  66,  51,  63,  85, 116,
         28,   2]), tensor([ 13, 100,   7,  18,  76,  89,  43, 107,  93,   9,  24, 100,  31,  56,
         44,  49,  11,  31,  82, 104, 111,  61,  23,  98,  97,  52, 104, 111,
          0,   2]), tensor([ 31,  56,  44,  49,  11,  31,  82, 104, 111,  28, 101,  34, 114,  67,
         71, 106,  10,  75,  36, 105,  27,  50,   4,  42,  75, 108,  81,  93,
         70,  30,  79,  66,  42,  75,  48,  86,  28, 102,  13,  60,  93,  70,
         39,  48,  86,  28, 102,  13,  84,   2]), tensor([ 76,  57,  93,  10,  75,  36, 105,  88,  15,  37,  71,  12,  29,  97,
         92,  28,  93,  77,  74,  95,  16,  14,  94,  46,  25,  69,  38,   2])]
[tensor([[-0.8271, -3.4207,  1.0076,  1.1374],
        [ 1.1116, -0.4994,  0.0981,  0.3237],
        [ 2.7497,  0.1902, -0.8299,  0.5370],
        [-0.2703, -2.3045,  2.2617,  0.4600],
        [ 0.7530,  0.3279,  0.5744, -0.4870],
        [ 2.3420,  0.9236, -0.7709,  2.1097],
        [-0.8132,  0.4612,  0.6180, -0.7583],
        [ 1.1512, -0.3037,  0.1258,  0.4924],
        [-0.1787,  1.3683, -1.2654,  1.1520],
        [-0.3326,  1.7316, -0.2990, -0.5603],
        [ 0.3380,  1.2284, -0.0406, -0.5216],
        [ 0.0777, -0.8270, -1.3772, -0.1009],
        [ 0.3611,  0.3317,  0.4928, -0.7461],
        [ 0.0747, -0.8892,  0.4481,  0.3589],
        [ 0.7482,  0.8593,  0.5952,  0.6671],
        [-0.2818, -1.2419, -0.2863,  0.3172],
        [ 0.7530,  0.3279,  0.5744, -0.4870],
        [-0.1206,  0.3578, -0.9191, -1.6470],
        [ 0.4042, -1.4756, -0.5672, -2.0382],
        [-1.0408, -0.8669,  0.5141, -1.4994],
        [ 0.1042, -1.7858,  1.8844,  0.4867],
        [ 1.7992,  2.0334,  0.0434, -0.7814]], grad_fn=<EmbeddingBackward0>), tensor([[-0.3582, -0.7115, -1.7531, -0.9850],
        [ 0.5675, -0.8603, -0.4381,  0.6289],
        [ 0.1376,  1.3232,  1.7879,  0.5816],
        [ 0.9818,  1.1166, -0.8948, -1.2513],
        [ 1.0693,  0.1953, -2.3745, -1.1714],
        [ 0.4113, -0.7838,  1.6471, -1.3885],
        [-1.6262, -0.2523, -0.2898, -1.2875],
        [-1.8421,  2.4769,  2.7960,  1.0011],
        [-1.7831, -0.5166, -0.4787,  1.4873],
        [-1.6262, -0.2523, -0.2898, -1.2875],
        [ 0.2028, -0.5652,  1.0369, -0.1741],
        [ 0.4239, -0.0271,  0.2380, -0.2279],
        [ 0.7530,  0.3279,  0.5744, -0.4870],
        [-0.2454,  0.8477, -0.6346,  1.4303],
        [-0.3238,  0.2671, -1.1727, -0.8511],
        [ 0.1376,  1.3232,  1.7879,  0.5816],
        [-0.3294,  1.5733, -0.6729,  0.7697],
        [-0.4431, -1.1818, -0.9103,  1.1734],
        [ 0.0777, -0.8270, -1.3772, -0.1009],
        [-0.4029, -1.0039, -1.4519,  1.2131],
        [-0.3353, -0.0749,  1.5571,  1.0565],
        [-0.3189,  1.2529,  1.1855, -0.2270],
        [ 0.4843,  0.2976,  1.2031,  2.9554],
        [ 1.2091,  0.1142,  1.3847, -0.1088],
        [ 0.5498, -0.1714, -0.2106, -1.7907],
        [ 0.4239, -0.0271,  0.2380, -0.2279],
        [-1.4678, -0.2963, -1.7785,  0.8152],
        [-1.6032, -0.0320,  0.5807,  1.8634],
        [ 1.7003, -0.7403, -1.4443,  0.9203],
        [-1.7831, -0.5166, -0.4787,  1.4873],
        [-1.6262, -0.2523, -0.2898, -1.2875],
        [ 0.2028, -0.5652,  1.0369, -0.1741],
        [ 0.4239, -0.0271,  0.2380, -0.2279],
        [ 1.7992,  2.0334,  0.0434, -0.7814]], grad_fn=<EmbeddingBackward0>), tensor([[-0.2287,  0.8708, -1.2204, -0.4580],
        [-0.3326,  1.7316, -0.2990, -0.5603],
        [ 0.1376,  1.3232,  1.7879,  0.5816],
        [ 1.7003, -0.7403, -1.4443,  0.9203],
        [ 1.3180,  1.0741, -1.4830, -0.1314],
        [ 2.3420,  0.9236, -0.7709,  2.1097],
        [-0.8132,  0.4612,  0.6180, -0.7583],
        [ 1.1512, -0.3037,  0.1258,  0.4924],
        [-0.1787,  1.3683, -1.2654,  1.1520],
        [-0.7300,  0.7559, -0.2909, -1.1319],
        [ 0.6367,  0.0270, -1.1265, -0.5191],
        [ 0.4113, -0.7838,  1.6471, -1.3885],
        [-1.6262, -0.2523, -0.2898, -1.2875],
        [-0.2454,  0.8477, -0.6346,  1.4303],
        [-0.3238,  0.2671, -1.1727, -0.8511],
        [ 0.1376,  1.3232,  1.7879,  0.5816],
        [ 0.3678,  0.2178, -0.6132,  1.1892],
        [-0.3326,  1.7316, -0.2990, -0.5603],
        [ 1.7230,  0.2492, -1.1080,  0.5950],
        [-1.3568, -1.1509,  1.4420, -0.6164],
        [-0.3294,  1.5733, -0.6729,  0.7697],
        [ 0.7327,  0.1396,  1.9950, -1.3080],
        [-0.6989,  0.0221, -1.0041, -0.7569],
        [ 0.1202,  0.5948, -1.8171,  2.5245],
        [ 1.0069,  0.8075, -0.0697,  0.7651],
        [-0.0585, -0.4487,  0.2124, -0.0717],
        [-0.3807,  0.2195, -1.1130, -0.5182],
        [-0.3033,  0.1475,  0.8471, -2.1588],
        [-0.8152, -0.0837,  0.5880, -0.6390],
        [-0.0509,  1.0256, -0.9774,  0.8224],
        [ 0.2028, -0.5652,  1.0369, -0.1741],
        [ 0.4239, -0.0271,  0.2380, -0.2279],
        [ 1.9825,  1.0140, -0.1613,  1.0952],
        [ 0.9012,  0.2991,  0.6070, -1.6379],
        [-0.7300,  0.7559, -0.2909, -1.1319],
        [ 0.6367,  0.0270, -1.1265, -0.5191],
        [ 0.4113, -0.7838,  1.6471, -1.3885],
        [-0.8191,  2.3775,  1.1809, -0.2766],
        [-1.0523, -0.4777,  0.2370, -1.1320],
        [-1.0408, -0.8669,  0.5141, -1.4994],
        [ 0.6031,  0.1546,  0.5192,  0.8785],
        [-0.1887,  2.5027, -1.1763, -0.5910],
        [ 2.4997,  0.1289, -1.6786, -1.1953],
        [-0.1588,  0.7292,  0.7088,  0.1943],
        [ 0.2756, -1.3107, -0.0064,  0.6357],
        [ 0.1376,  1.3232,  1.7879,  0.5816],
        [-0.2187, -0.0971,  2.1553,  0.0777],
        [ 0.3097, -0.4480, -1.4339,  0.2057],
        [-0.9975,  0.0306, -0.0675,  0.2023],
        [ 0.3611,  0.3317,  0.4928, -0.7461],
        [ 1.7003, -0.7403, -1.4443,  0.9203],
        [-0.8191,  2.3775,  1.1809, -0.2766],
        [-0.1588,  0.7292,  0.7088,  0.1943],
        [ 0.2756, -1.3107, -0.0064,  0.6357],
        [-0.3294,  1.5733, -0.6729,  0.7697],
        [ 0.7327,  0.1396,  1.9950, -1.3080],
        [-0.3326,  1.7316, -0.2990, -0.5603],
        [-0.8152, -0.0837,  0.5880, -0.6390],
        [-0.7300,  0.7559, -0.2909, -1.1319],
        [ 0.6367,  0.0270, -1.1265, -0.5191],
        [ 0.7530,  0.3279,  0.5744, -0.4870],
        [ 0.6031,  0.1546,  0.5192,  0.8785],
        [ 0.3040,  1.6729, -1.2441, -2.0695],
        [-0.8152, -0.0837,  0.5880, -0.6390],
        [-0.0372, -1.0394,  0.9333, -1.7641],
        [-1.8421,  2.4769,  2.7960,  1.0011],
        [ 0.3097, -0.4480, -1.4339,  0.2057],
        [-0.9975,  0.0306, -0.0675,  0.2023],
        [-0.1625, -1.5905, -0.6695, -0.4197],
        [ 0.8374,  0.4385,  0.5106, -1.1942],
        [ 0.7530,  0.3279,  0.5744, -0.4870],
        [ 1.7992,  2.0334,  0.0434, -0.7814]], grad_fn=<EmbeddingBackward0>), tensor([[-0.3033,  0.1475,  0.8471, -2.1588],
        [ 0.8569,  1.0863,  0.4669,  1.1277],
        [-1.2574,  0.7327,  0.5459,  1.3093],
        [-0.2403, -2.1121,  0.0654, -0.5656],
        [ 1.7003, -0.7403, -1.4443,  0.9203],
        [ 1.3180,  1.0741, -1.4830, -0.1314],
        [-0.2454,  0.8477, -0.6346,  1.4303],
        [-0.3238,  0.2671, -1.1727, -0.8511],
        [ 0.1376,  1.3232,  1.7879,  0.5816],
        [ 1.3658,  0.4408,  0.2941,  1.6885],
        [ 0.3223, -1.1167,  0.3408, -0.1121],
        [ 0.8569,  1.0863,  0.4669,  1.1277],
        [ 0.9538,  0.4298, -1.4315,  1.1654],
        [-0.7486,  1.4318,  0.1761,  0.5564],
        [-0.8676,  1.8610, -0.7512, -1.1514],
        [ 0.9734,  1.7933, -1.2475,  0.8386],
        [ 0.4891,  0.0081,  0.1878, -0.5701],
        [ 0.9538,  0.4298, -1.4315,  1.1654],
        [ 0.7189,  1.3097, -0.7027,  0.1364],
        [ 2.3420,  0.9236, -0.7709,  2.1097],
        [-0.8132,  0.4612,  0.6180, -0.7583],
        [-1.0550,  1.0984,  0.7286, -0.2470],
        [-0.9643,  1.2373, -0.9732,  0.5179],
        [ 1.7992,  0.3802,  0.8440, -0.1815],
        [-0.3983, -0.7383, -1.9083, -0.9365],
        [-0.9057, -0.1811,  1.5420,  0.4628],
        [ 2.3420,  0.9236, -0.7709,  2.1097],
        [-0.8132,  0.4612,  0.6180, -0.7583],
        [ 0.7620,  0.3044,  0.3303,  0.2478],
        [ 1.7992,  2.0334,  0.0434, -0.7814]], grad_fn=<EmbeddingBackward0>), tensor([[ 0.9538,  0.4298, -1.4315,  1.1654],
        [-0.7486,  1.4318,  0.1761,  0.5564],
        [-0.8676,  1.8610, -0.7512, -1.1514],
        [ 0.9734,  1.7933, -1.2475,  0.8386],
        [ 0.4891,  0.0081,  0.1878, -0.5701],
        [ 0.9538,  0.4298, -1.4315,  1.1654],
        [ 0.7189,  1.3097, -0.7027,  0.1364],
        [ 2.3420,  0.9236, -0.7709,  2.1097],
        [-0.8132,  0.4612,  0.6180, -0.7583],
        [ 0.7530,  0.3279,  0.5744, -0.4870],
        [ 3.1833,  0.9793, -2.2142,  1.1424],
        [ 0.5528,  1.0870, -0.4082,  0.0372],
        [-0.3247,  0.7986,  1.3202, -0.1900],
        [ 0.2143,  0.3669,  0.0974, -0.2620],
        [-0.3326,  1.7316, -0.2990, -0.5603],
        [ 1.6495,  0.8272, -0.2385, -0.6014],
        [-1.0911,  1.4931,  0.0144,  0.0616],
        [-1.6262, -0.2523, -0.2898, -1.2875],
        [ 0.2028, -0.5652,  1.0369, -0.1741],
        [ 0.4239, -0.0271,  0.2380, -0.2279],
        [-1.7269, -2.1262, -1.1280, -1.4431],
        [ 0.3995, -0.3486,  0.6831, -1.2386],
        [-0.2818, -1.2419, -0.2863,  0.3172],
        [ 0.4113, -0.7838,  1.6471, -1.3885],
        [-1.6262, -0.2523, -0.2898, -1.2875],
        [-0.1443,  0.4913, -0.7174, -1.1175],
        [-0.9000,  2.3307, -0.8306, -0.8913],
        [ 0.1376,  1.3232,  1.7879,  0.5816],
        [-0.2269, -0.0196,  0.7214,  0.8138],
        [-0.7023,  1.2319, -1.0195,  0.3381],
        [ 1.3086, -0.2187,  0.1623,  1.3076],
        [-1.8421,  2.4769,  2.7960,  1.0011],
        [ 0.4113, -0.7838,  1.6471, -1.3885],
        [-1.6262, -0.2523, -0.2898, -1.2875],
        [ 1.6790,  0.3225, -3.1174, -1.4824],
        [ 0.8470,  1.6889, -1.0368,  1.4659],
        [ 0.7530,  0.3279,  0.5744, -0.4870],
        [-0.6789,  0.2614,  1.2435, -2.5316],
        [-0.3033,  0.1475,  0.8471, -2.1588],
        [-0.4029, -1.0039, -1.4519,  1.2131],
        [ 0.1376,  1.3232,  1.7879,  0.5816],
        [-0.2269, -0.0196,  0.7214,  0.8138],
        [-1.1739, -0.7361,  1.1430,  0.0355],
        [ 1.6790,  0.3225, -3.1174, -1.4824],
        [ 0.8470,  1.6889, -1.0368,  1.4659],
        [ 0.7530,  0.3279,  0.5744, -0.4870],
        [-0.6789,  0.2614,  1.2435, -2.5316],
        [-0.3033,  0.1475,  0.8471, -2.1588],
        [ 0.7528, -0.2440,  0.4287, -1.0538],
        [ 1.7992,  2.0334,  0.0434, -0.7814]], grad_fn=<EmbeddingBackward0>), tensor([[ 1.7003, -0.7403, -1.4443,  0.9203],
        [ 0.2293,  0.1013, -0.4674,  2.6180],
        [ 0.1376,  1.3232,  1.7879,  0.5816],
        [-1.0911,  1.4931,  0.0144,  0.0616],
        [-1.6262, -0.2523, -0.2898, -1.2875],
        [ 0.2028, -0.5652,  1.0369, -0.1741],
        [ 0.4239, -0.0271,  0.2380, -0.2279],
        [ 1.9825,  1.0140, -0.1613,  1.0952],
        [ 0.9012,  0.2991,  0.6070, -1.6379],
        [ 0.3678,  0.2178, -0.6132,  1.1892],
        [-0.3326,  1.7316, -0.2990, -0.5603],
        [-1.1170, -1.4170,  1.0364,  0.7652],
        [ 0.9543, -1.4128,  1.1750, -0.2770],
        [-0.3983, -0.7383, -1.9083, -0.9365],
        [ 0.0483, -0.3987,  0.3003, -0.7949],
        [ 0.7530,  0.3279,  0.5744, -0.4870],
        [ 0.1376,  1.3232,  1.7879,  0.5816],
        [-0.8152, -0.0837,  0.5880, -0.6390],
        [-0.7300,  0.7559, -0.2909, -1.1319],
        [ 0.6367,  0.0270, -1.1265, -0.5191],
        [-1.0523, -0.4777,  0.2370, -1.1320],
        [-1.0408, -0.8669,  0.5141, -1.4994],
        [ 0.6031,  0.1546,  0.5192,  0.8785],
        [-0.1887,  2.5027, -1.1763, -0.5910],
        [ 2.4997,  0.1289, -1.6786, -1.1953],
        [-0.1588,  0.7292,  0.7088,  0.1943],
        [ 0.2756, -1.3107, -0.0064,  0.6357],
        [ 1.7992,  2.0334,  0.0434, -0.7814]], grad_fn=<EmbeddingBackward0>)]
embedding_bag = nn.EmbeddingBag(vocab_size, embedding_dim, mode='mean')

# EmbeddingBag takes one flat index tensor plus offsets marking where each "bag" starts.
# Here only two bags are defined: the first sentence, and all of the remaining tokens together.
input_sequence = torch.cat(input_sequences)
offsets = torch.tensor([0, len(input_sequences[0])], dtype=torch.long)
embedding_bag_res = embedding_bag(input_sequence, offsets)  # one mean-pooled vector per bag
print(embedding_bag_res)

tensor([[-0.3011, -0.3997, -0.2998, -0.0833],
        [-0.0472,  0.0584,  0.1322, -0.1476]], grad_fn=<EmbeddingBagBackward0>)
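
With only those two offsets, the result has just two rows: the mean of the first sentence's vectors and the mean of everything else. To pool each sentence into its own vector, the offsets can instead be derived from the cumulative sequence lengths, for example:

# One offset per sentence: the running start position of each sequence in the flat tensor.
lengths = [len(seq) for seq in input_sequences]
offsets_all = torch.tensor([0] + lengths[:-1], dtype=torch.long).cumsum(dim=0)
print(embedding_bag(input_sequence, offsets_all).shape)  # torch.Size([6, 4]): one pooled vector per sentence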

