torch.nn.Embedding可将数据的离散表示转换成连续的表示。以NLP为例,若输入数据为中文句子,则每个汉字可表示成one-hot向量,但此方法若输入数据涉及的汉字数量过大,则输入维度可达几十万维,故可利用Embedding将输入数据进行嵌入表示。可将Embedding看作一个查询表(二维),一个汉字对应查询表的一行数据,则行数就代表所涉及汉字的总个数,列数则代表嵌入的维度(自己定)。 torch.nn.Embedding(num_embeddings, embedding_dim)的参数主要有num_embeddings和embedding_dim两个,其中num_embeddings代表查询表的大小(行数),embedding_dim代表嵌入维度。
import torch
embedding = torch.nn.Embedding(num_embeddings=4, embedding_dim=2) # 创建4x2的查询表
print(embedding.weight)
test = torch.randint(0, 4, (2, 2)) # 生成数据,维度为2x2
print(test)
test = embedding(test) # 将数据进行嵌入,输出维度为2x2x2
print(test)
"""
Parameter containing:
tensor([[ 0.1208, -1.6480],
[ 0.3138, -0.9083],
[-1.0675, 0.0510],
[-0.4471, 1.4563]], requires_grad=True)
tensor([[1, 3],
[2, 2]])
tensor([[[ 0.3138, -0.9083],
[-0.4471, 1.4563]],
[[-1.0675, 0.0510],
[-1.0675, 0.0510]]], grad_fn=<EmbeddingBackward0>)
"""
如图所示,经过Embedding后,生成的数据test中的1,3,2,2被查询表中行数索引为1,3,2,2的向量代替了。
注意:要进行嵌入的数据为整数(表示位置),且不能超过查询表的行数