【Pytorch学习】nn.Embedding的讲解及使用


Pytorch官网

https://pytorch.org/docs/stable/index.html


nn.Embedding

官方文档的解释

https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding
在这里插入图片描述
在这里插入图片描述

  • 一个简单的查找表,存储一个固定字典的嵌入(embedding)和大小
  • 这个模块通常被用于存储词嵌入,使用索引检索出来这些词嵌入(word embedding)。输入是一列索引(indices),输出是对应的词嵌入

举例分析

>>> weight = torch.FloatTensor([[1, 2.1, 3], 
                                [4.2, 5.1, 6], 
                                [7.1, 8, 9.4]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> input = torch.LongTensor([1, 2])
>>> embedding(input)
tensor([[4.2000, 5.1000, 6.0000],
        [7.1000, 8.0000, 9.4000]])
  • embedding = nn.Embedding(10, 3)定义了一个有10个张量,每个张量大小为3的embedding模块
  • input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])输入是有两个样本,每个样本有4个索引的batch
  • embedding(input)将input放进这个embedding模块,它的输出是一个[2, 4, 3]的张量

进一步分析这个输出的[2, 4, 3]的张量是怎么得到的
我们打印一下这个Embedding模型的weight:


>>> embedding.weight
Parameter containing:
tensor([[-1.1994, -0.1827, -1.3696],
        [-0.4404, -0.4694,  1.2834],
        [-0.2315,  0.8612,  0.4055],
        [-1.4578, -1.5296, -1.1858],
        [ 1.2688, -1.0777, -1.6466],
        [ 0.3597,  1.3453,  1.4866],
        [ 1.3760,  0.2611, -0.6241],
        [-2.3062, -1.0274,  0.3889],
        [ 0.4912, -0.1535, -0.5067],
        [ 0.2230, -1.6161,  1.1694]], requires_grad=True)

这个weight的shape就是最开始我们给他的参数[10, 3],那具体weight和输出(embedding(input))之间的映射关系是怎样的呢?把他俩放在一块看一下:
在这里插入图片描述1. 可以发现input的索引首先从weight中找到索引对应的值,然后所有索引在embedding.weight中对应值所形成的张量就是embedding(input)

例如:
索引1所对应的就是[-0.4404, -0.4694, 1.2834]
索引2所对应的就是[-0.2315, 0.8612, 0.4055]
索引4所对应的就是[ 1.2688, -1.0777, -1.6466]
索引5所对应的就是[ 0.3597, 1.3453, 1.4866]
input中的[1, 2, 4, 5]每个索引在embedding.weight中对应的几个值,就对应左侧输出的红框部分

  1. embedding(input)的维度就是[2, 4, 3]:输入input的维度[2, 4]多加一个Embedding模块[10, 3]的第二个维度

nn.Embedding.from_pretrained

官方文档的解释

在这里插入图片描述在这里插入图片描述embedding可以接受已经定义好的tensor作为weight,这个已经指定好的tensor必须是只有两个维度的FloatTensor

举例

>>> weight = torch.FloatTensor([[1, 2.1, 3], 
                                [4.2, 5.1, 6], 
                                [7.1, 8, 9.4]])
>>> embedding = nn.Embedding.from_pretrained(weight)
>>> input = torch.LongTensor([1, 2])
>>> embedding(input)
tensor([[4.2000, 5.1000, 6.0000],
        [7.1000, 8.0000, 9.4000]])

总结

总结了对nn.embedding的理解及使用,用输入索引映射出对应的词嵌入(embedding)。还有nn.Embedding.from_pretrained的使用。

努力学习多多总结,欢迎交流

参考

无脑入门pytorch系列(一)—— nn.embedding - 暗夜无风的文章 - 知乎
https://zhuanlan.zhihu.com/p/647536930

http://t.csdnimg.cn/mNCX2

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值