在做机器学习的任务时,经常需要将一些人类理解的语句转为机器理解的数值符号,一种比较常见的做法是将词语或句子转换成向量的形式。在这个转换的过程中,可以利用 PyTorch 中提供的torch.nn.Embedding 接口来实现词语的向量化,下面将对 torch.nnEmbedding 接口使用进行演示。
一、接口介绍
接口官方地址:Embedding — PyTorch 1.13 documentation
CLASS
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, device=None, dtype=None)
torch.nn.Embedding 接口模块本质上是一个字典(A lookup table),字典中每个索引对应一个词的 Embedding 向量形式。这个向量是随机初始化的(满足正态分布 N ( 0 , 1 ) N(0,1) N(0,1) 中随机取值),即不代表任何含义,所以不会有 word2vec 等方法训练出来的效果,但是可以利用这样的方法先赋值,之后再进行学习。
接口的常用参数
-
num_embeddings (int) – 字典中的词语数
-
embedding_dim (int)– 嵌入向量的维度
-
padding_idx (int, optional) – 如果给定,则对 padding_idx 中索引对应的位置填0
二、接口使用
# 词语集
words = ['A', 'B', 'C', 'D', 'E']
# 词号映射
word_idx = {w: idx for idx, w in enumerate(words)}
# 语句集
sentences = [
'AABBC',
'AADDE',
'EECAA'
]
num_embed = len(words) # 词典的单词数
embed_dim = 3 # 嵌入向量的维度(自定义)
embed = nn.Embedding(num_embed, embed_dim)
# 将句子中的每个单词转成对应的编号
input = []
for s in sentences:
s_idx = [word_idx[w] for w in s]
input.append(s_idx)
# 对语句集中的语句进行 embedding
input = torch.LongTensor(input)
print(embed(input))
其中,输入的 input 是将语句集编号化后得到的数字集,输出的结果是每个句子的嵌入表示:
Input:
tensor([[0, 0, 1, 1, 2],
[0, 0, 3, 3, 4],
[4, 4, 2, 0, 0]])
Embedding Result:
tensor([[[ 0.3174, -0.1958, -1.1196],
[ 0.3174, -0.1958, -1.1196],
[ 0.5466, -0.6627, -2.0538],
[ 0.5466, -0.6627, -2.0538],
[ 2.2772, -0.3313, 0.3458]],
[[ 0.3174, -0.1958, -1.1196],
[ 0.3174, -0.1958, -1.1196],
[ 1.0199, 0.1071, 2.3243],
[ 1.0199, 0.1071, 2.3243],
[-0.0983, -1.4985, -0.1011]],
[[-0.0983, -1.4985, -0.1011],
[-0.0983, -1.4985, -0.1011],
[ 2.2772, -0.3313, 0.3458],
[ 0.3174, -0.1958, -1.1196],
[ 0.3174, -0.1958, -1.1196]]], grad_fn=<EmbeddingBackward0>)
从 Embedding Result 中可以看出,此时每个单词已经被转换成维度为 3 的向量表示,而由于初始设置的语句由 5 个单词组成,因此,每个语句被映射为由 5 个单词向量组成的二维矩阵。当需要查询每一个单词对应的具体向量时,可以查看 embedding.weight,其中每一个向量对应 index 指向的词,即由 embedding.weight 存储对应的词向量,需要时直接在 embedding.weight 中查询。
print(embed.weight)
Parameter containing:
tensor([[ 0.3174, -0.1958, -1.1196],
[ 0.5466, -0.6627, -2.0538],
[ 2.2772, -0.3313, 0.3458],
[ 1.0199, 0.1071, 2.3243],
[-0.0983, -1.4985, -0.1011]], requires_grad=True)