torch.nn.Embedding参数详解-之-num_embeddings,embedding_dim

torch.nn.Embedding()

1、关于torch.nn.Embedding()前两个参数的详解。

先贴函数全貌

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, device=None, dtype=None

2、详解常用的前两个参数。常用的调用形式:nn.Embedding(10, 3)

平时调用的一般形式如下(官网的例子):

>>> # an Embedding module containing 10 tensors of size 3
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

2.1、第一个参数的来源:

其中nn.Embedding(10,3)中的10:10=maximum index + 1即input的最大值加1,上面input最大为9,所以这里的第一个参数为10.
实验一下将input的最大值改为10,结果报错。
在这里插入图片描述

2.2、第二个参数的来源:

其中nn.Embedding(10,3)中的3,表示是我们指定的nn.Embedding()输出的结果中每个向量(最里面的[])包含3个元素。如下图所示:
在这里插入图片描述

3、总结:

所以torch.nn.Embedding(num_embeddings, num_embeddings,省略)这两个必须要填的参数,num_embeddings是需要进行embedding的数据决定的,num_embeddings是我们自己决定的。

  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值