关于torch.nn.Embedding的浅显理解

最近在使用词嵌入向量表示我的数据标签,并且在试图理解torch.nn.Embedding函数。

函数提供一个简单的查找表,输入主要为词字典的大小和词嵌入的维度两个参数,输出为对应的词嵌入向量。

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)

词字典的大小num_embeddings限制了索引大小只能从0~ num_embeddings-1。num_embeddings(int) – size of the dictionary of embeddings,就是你给nn.Embedding函数的张量里的索引个数要在0~num_embeddings-1之间;embedding_dim (int) – the size of each embedding vector也即生成的词嵌入向量的最后一个维度。For example:

import torch.nn as nn
import torch

embedding = nn.Embedding(10, 3)

input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])

这里输入的向量input里包含的索引:1,2,3,4,5,9 均在[0,10)之间。embdding的第二个参数就决定了input的每一个数会被扩展到3维。所以最后生成的词嵌入维度如下,其中出现了两个2和两个4,因此索引出来了两个相同的词嵌入向量[-0.6431, 0.0748, 0.6969]和[ 1.4970, 1.3448, -0.9685]。

embedding(input)
        tensor([[[-0.0251, -1.6902,  0.7172],
                 [-0.6431,  0.0748,  0.6969],
                 [ 1.4970,  1.3448, -0.9685],
                 [-0.3677, -2.7265, -0.1685]],

                [[ 1.4970,  1.3448, -0.9685],
                 [ 0.4362, -0.4004,  0.9400],
                 [-0.6431,  0.0748,  0.6969],
                 [ 0.9124, -2.3616,  1.1151]]])
  • 9
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

starleeisamyth

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值