torch.nn.Embedding

文章目录

方法介绍

CLASS torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)
  • 是一个存储固定字典和大小的 lookup table
  • 常用来存储word embedding,并使用索引来检索他们,输入是一个索引列表,输出是对应的word embeddings

主要参数介绍:

  • num_embeddings: embedding字典大小
  • embedding_dim:每个embedding向量的大小、
  • pddding_idx(optional):如果设置了 padding_idx,其对应的 entries 就不会对梯度有贡献。
    • 因此,在 padding_idx 位置的 embedding 在训练的时候不会更新,是一个固定的 pad
    • 对一个新创建的 Embedding,padding_idx处的embedding向量是全0,也可以用其他值来代替

例子

不含 padding_idx

input = torch.tensor([[0, 1, 2], [1, 2, 3]])
embedding = nn.Embedding(4, 3)
print(embedding(input))
>tensor([[[ 0.5663,  1.6136, -0.7869],
         [ 0.1662,  0.1890,  1.2086],
         [-1.1227, -0.0549, -0.8810]],

        [[ 0.1662,  0.1890,  1.2086],
         [-1.1227, -0.0549, -0.8810],
         [ 1.1553, -0.9264, -2.2916]]], grad_fn=<EmbeddingBackward0>)
  • 可以看到,Embedding 创建了一个 4x3 的矩阵,一共可以表示4个索引,每个表示都是一个包含3个数字的向量

包含 padding_idx

input = torch.tensor([[0, 1, 2], [1, 2, 3]])
embedding = nn.Embedding(4, 3, padding_idx=1)
print(embedding(input))
>tensor([[[-1.8803, -0.3651,  0.0161],
         [ 0.0000,  0.0000,  0.0000],
         [ 1.6608, -0.0528,  1.3493]],

        [[ 0.0000,  0.0000,  0.0000],
         [ 1.6608, -0.0528,  1.3493],
         [ 2.0451,  0.5492,  0.5848]]], grad_fn=<EmbeddingBackward0>)
  • padding_idx=1 的embedding vector为全 0

padding_idx在backward gradient时会被忽略,参考链接

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

长命百岁️

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值