白话解释
问题提出
下面代码中为什么y比x多了一个维度呢?
x的shape、y的shape中每个数字又表示了什么?
import torch
import torch.nn as nn
x = torch.tensor([[0, 1], [0, 3]])
print(x.shape) # torch.Size([2, 2])
e = nn.Embedding(num_embeddings=4, embedding_dim=5)
y = e(x)
print(y.shape) # torch.Size([2, 2, 5])
函数简介
torch.nn.embedding的作用就是 用 嵌入向量 表示 词索引。
torch.nn.embedding的常用参数如下:
torch.nn.Embedding(num_embeddings, embedding_dim)
- num_embeddings:表示词索引取值范围的长度。该长度值为该参数的最小值。
- embedding_dim:表示一个嵌入向量的长度。
举例
假设有一个小的词汇表,包含以下单词和它们对应的索引(注意索引是从0开始的):
- “hello” -> 0
- “world” -> 1
- “pytorch” -> 2
- “embedding” -> 3
如果现在想用tensor来达“hello world”、“hello pytorch”,那么只需要让这些单词的 词索引 组成一个tensor就好。
“hello world”对应的词索引表为[0,1]
"hello pytorch"对应的词索引表为[0,3]
下面用一个tensor来表示这两句话:
x = torch.tensor([[0,1],[0,3]])
x.shape#(2,2)
- 这里x的shape含义为:有两个序列(代表上面那两句话),每个序列均由两个词索引组成(每句话均由对应的词索引所代表的词组成)。
接下来来看embedding
e = nn.Embedding(num_embeddings=4, embedding_dim=5)
- num_embeddings=4:上面一共只有4个不同的词,索引范围为0~3,长度为4,所以这个参数为4。当然也可以取更高,但是不能低于4。
- embedding_dim=5:即词向量的长度为5。一个词向量表示一个词索引。这里5表示一个长度为5的词向量可以表示一个词索引。这里也可以不取5,其实可以取你想要的嵌入向量的长度数,取5只是举例。
最后来看下y的shape
y = e(x)
print(y.shape) # torch.Size([2, 2, 5])
- [2,2,5]表示有2个序列,每个序列由2个词索引组成,每个词索引均由一个长度为5的嵌入向量表示。
最后
torch.nn.Embedding的作用就是将 词索引 用 嵌入向量表示,于是就增加了一个维度,这个维度大小正是一个嵌入向量长度。