Pytorch官网
https://pytorch.org/docs/stable/index.html
nn.Embedding
官方文档的解释
https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding
- 一个简单的查找表,存储一个固定字典的嵌入(embedding)和大小
- 这个模块通常被用于存储词嵌入,使用索引检索出来这些词嵌入(word embedding)。输入是一列索引(indices),输出是对应的词嵌入
举例分析
>>> weight = torch.FloatTensor([[1, 2.1, 3],
[4.2, 5.1, 6],
[7.1, 8, 9.4]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> input = torch.LongTensor([1, 2])
>>> embedding(input)
tensor([[4.2000, 5.1000, 6.0000],
[7.1000, 8.0000, 9.4000]])
- embedding = nn.Embedding(10, 3)定义了一个有10个张量,每个张量大小为3的embedding模块
- input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])输入是有两个样本,每个样本有4个索引的batch
- embedding(input)将input放进这个embedding模块,它的输出是一个[2, 4, 3]的张量
进一步分析这个输出的[2, 4, 3]的张量是怎么得到的
我们打印一下这个Embedding模型的weight:
>>> embedding.weight
Parameter containing:
tensor([[-1.1994, -0.1827, -1.3696],
[-0.4404, -0.4694, 1.2834],
[-0.2315, 0.8612, 0.4055],
[-1.4578, -1.5296, -1.1858],
[ 1.2688, -1.0777, -1.6466],
[ 0.3597, 1.3453, 1.4866],
[ 1.3760, 0.2611, -0.6241],
[-2.3062, -1.0274, 0.3889],
[ 0.4912, -0.1535, -0.5067],
[ 0.2230, -1.6161, 1.1694]], requires_grad=True)
这个weight的shape就是最开始我们给他的参数[10, 3],那具体weight和输出(embedding(input))之间的映射关系是怎样的呢?把他俩放在一块看一下:
1. 可以发现input的索引首先从weight中找到索引对应的值,然后所有索引在embedding.weight中对应值所形成的张量就是embedding(input)
例如:
索引1所对应的就是[-0.4404, -0.4694, 1.2834]
索引2所对应的就是[-0.2315, 0.8612, 0.4055]
索引4所对应的就是[ 1.2688, -1.0777, -1.6466]
索引5所对应的就是[ 0.3597, 1.3453, 1.4866]
input中的[1, 2, 4, 5]每个索引在embedding.weight中对应的几个值,就对应左侧输出的红框部分
- embedding(input)的维度就是[2, 4, 3]:输入input的维度[2, 4]多加一个Embedding模块[10, 3]的第二个维度
nn.Embedding.from_pretrained
官方文档的解释
embedding可以接受已经定义好的tensor作为weight,这个已经指定好的tensor必须是只有两个维度的FloatTensor
举例
>>> weight = torch.FloatTensor([[1, 2.1, 3],
[4.2, 5.1, 6],
[7.1, 8, 9.4]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> input = torch.LongTensor([1, 2])
>>> embedding(input)
tensor([[4.2000, 5.1000, 6.0000],
[7.1000, 8.0000, 9.4000]])
总结
总结了对nn.embedding的理解及使用,用输入索引映射出对应的词嵌入(embedding)。还有nn.Embedding.from_pretrained的使用。
努力学习多多总结,欢迎交流
参考
无脑入门pytorch系列(一)—— nn.embedding - 暗夜无风的文章 - 知乎
https://zhuanlan.zhihu.com/p/647536930