Pytorch 中的Embedding理解

1 Word Embed

理解Embeding首先需要理解Word Embed,那什么是Word Embed呢?

即将word转化为tensor:

This is an apple!  ->  [[0.1, 0.2], [0.2, 0.3], [0.1, 0.4], [0.2, 0.2]]

词嵌入的过程,就相当于是我们在给计算机制造出一本字典的过程, key为index或者词,value为向量

通过这个字典来间接地识别文字,那这个如何理解,首先就要知道one-hot向量

1.1 One-hot编码

rap   -> [0, 0, 0, 0, 1, 0]
apple -> [0, 0, 0, 1, 0, 0]
is    -> [0, 0, 1, 0, 0, 0]
red   -> [0, 1, 0, 0, 0, 0]

每一个单词都被表示为one-hot向量,建立了一个字典,通过[0, 0, 0, 0, 1]能查询到单词this, 但是one-hot 编码有两个问题:

  1. 难以对大的文本数据建模,想象一下,对于有上万单词的文本建立one-hot的维度就有上万维度
  2. 难以表示文本之间的语义关系

对于第一个问题,这个比较好理解

对于第二个问题:

Rap    -> [0, 0, 0, 0, 1, 0]
banana -> [1, 0, 0, 0, 0, 0]
is     -> [0, 0, 1, 0, 0, 0]
yellow -> [0, 0, 0, 0, 0, 1]

apple和banana是有语义关系的,它们都是水果,都是成熟的,但是使用one-hot向量难以表示其内在的关系,使用词嵌入来解决这个问题

1.2 Word Embed理解

可以知道apple和banana都是水果,它们都是熟的,一个生长在热带一个生长在温带,根据这三个关系,我们建立词嵌入

		rap   friut tropical
apple   [1.0, 1.0,  0.2]
banana  [1.0, 1.0,  0.8]

因此我们根据这样的规则就可以建立一个词嵌入,并能表示他们的关系, 因此提出词嵌入:Word Embed,它是将词按照一定的规则映射为一个向量

同时可知他们之间的关系:

apple_embed = banana_embed[0] * 1 + banana_embed[1] * 1 + banana_embed[2] * 0.25

1.3 总结

one-hotquery-embed

  • 单词的向量表示从稀疏态变成了密集态
  • 相互独立向量建立了内在联系的关系

2 Pytorch中的embeding

Doc: https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html?highlight=embeding

A simple lookup table that stores embeddings of a fixed dictionary and size.

This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding word embeddings.

建立一个字典:将一个scale(index) 以一定的规则映射为一个tensor,使用index进行索引,并且这个规则是可学习的

参数:

  • init: num_embeddings:下标的长度 , embedding_dim :向量的维度
  • Input: 输入下标大小,输入范围为:>=0 < num_embedding,按照词的index进行编码
  • output: 输出编码后的向量,前面的维度与input的前面的维度相同,维度为embedding_dim
  • 初始化的规则:均值为0,方差为1的正态分布,并且是可学习的
weight (Tensor) – the learnable weights of the module of shape (num_embeddings, embedding_dim) initialized from N(0,1)

3 Example

3.1 Ex1 一维编码

import torch
import torch.nn as nn

# 定义词嵌入
num_query = 10
hidden_dim = 3
embed = nn.Embedding(num_query, hidden_dim)

# 定义输入
input_tensor = torch.tensor([1, 2, 3, 1])
b = embed(input_tensor)


# 输出系数
embed.weight, b

>>>

(Parameter containing:
 tensor([[ 0.0739,  0.1857,  0.0616],
         [ 0.5078,  0.9804,  1.2938],
         [-1.6650,  0.1880,  0.3921],
         [ 0.7753,  1.5514,  0.2877],
         [ 0.1817, -0.7859, -0.7943],
         [-1.5885,  0.3873,  0.2072],
         [-0.4109, -2.1112,  1.7696],
         [ 0.6706, -0.1292, -2.3005],
         [ 1.3861, -0.5765,  1.3436],
         [-0.8233, -0.1140,  0.6213]], requires_grad=True),
 tensor([[ 0.5078,  0.9804,  1.2938],
         [-1.6650,  0.1880,  0.3921],
         [ 0.7753,  1.5514,  0.2877],
         [ 0.5078,  0.9804,  1.2938]], grad_fn=<EmbeddingBackward0>))

3.2 Ex2 多维度编码

import torch
import torch.nn as nn

# 定义词嵌入
num_query = 10
hidden_dim = 3
embed = nn.Embedding(num_query, hidden_dim)

# 定义输入
input_tensor = torch.tensor([[1, 2], [3, 1]])
b = embed(input_tensor)


# 输出系数
embed.weight, b

>>>
(Parameter containing:
 tensor([[-0.1994, -0.7068,  0.1735],
         [-0.6465, -0.1401, -0.1157],
         [-0.3184,  0.5307, -0.2699],
         [-1.2810,  2.5731, -2.0119],
         [-0.5721, -0.9300, -0.9147],
         [ 1.0049,  0.5798,  0.0149],
         [ 0.5879, -1.5378,  1.4790],
         [ 0.8585,  1.5588, -1.8052],
         [ 0.6283, -1.2635,  0.8070],
         [-1.5081, -0.2307, -1.0041]], requires_grad=True),
 tensor([[[-0.6465, -0.1401, -0.1157],
          [-0.3184,  0.5307, -0.2699]],
 
         [[-1.2810,  2.5731, -2.0119],
          [-0.6465, -0.1401, -0.1157]]], grad_fn=<EmbeddingBackward0>))

3.3 超过下标后报错

import torch
import torch.nn as nn

# 定义词嵌入
num_query = 10
hidden_dim = 3
embed = nn.Embedding(num_query, hidden_dim)

# 定义输入
input_tensor = torch.tensor([0, 10])
b = embed(input_tensor)


# 输出系数
embed.weight, b

>>>

IndexError Traceback (most recent call last) /tmp/ipykernel_2335/1048340143.py in <module>
      9 # 定义输入
     10 input_tensor = torch.tensor([0, 10])
---> 11 b = embed(input_tensor)
********
IndexError: index out of range in self

通过上面的实验可以看出每次embed.weight都是随机初始化的,下标查询的维度不可超出num_embeddings


参考

pytorch embedding层详解(从原理到实战)https://blog.csdn.net/weixin_43914889/article/details/104699657

深入理解 Embedding层的本质 https://blog.csdn.net/weixin_42078618/article/details/84553940

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值