torch.nn.functional.embedding的参数理解,尤其是weight

本文介绍了PyTorch中的词嵌入函数torch.nn.functional.embedding,详细解释了其参数input和weight的作用,并通过实例展示了如何将索引映射到对应的词向量。该函数用于将输入的索引转换为预训练的词向量,以进行文本表示和后续的深度学习任务。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

介绍一下我们常用的嵌入函数torch.nn.functional.embedding,先看一下参数:torch.nn.functional.embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)。我们常使用的就是前两个参数。input是在词向量矩阵中的索引列表,词向量矩阵,行数为最大可能的索引数+1,列数为词向量的维度。那么具体是什么含义呢?
input参数是我们想要用向量表示的对象(可以是文本、图等)的索引,weight储存了我们的向量表示,这个函数的目的就是输出一个索引和向量的对应关系。看一个例子:

import numpy as np
import torch.nn.functional as F
import torch

input = torch.tensor([[1,2,4,5],[4,3,2,8]])
embedding_matrix = torch.tensor(
        [[0.7330, 0.9718, 0.9023],
        [0.7769, 0.7640, 0.3664],
        [0.6036, 0.3873, 0.5681],
        [0.8422, 0.6275, 0.5400],
        [0.0346, 0.5622, 0.2547],
        [0.4926, 0.9282, 0.1762],
        [0.0037, 0.5831, 0.4443],
        [0.2001, 0.1086, 0.0518],
        [0.6574, 0.9185, 0.3451]])
print(F.embedding(input, embedding_matrix))

结果:

tensor([[[0.7769, 0.7640, 0.3664],
         [0.6036, 0.3873, 0.5681],
         [0.0346, 0.5622, 0.2547],
         [0.4926, 0.9282, 0.1762]],

        [[0.0346, 0.5622, 0.2547],
         [0.8422, 0.6275, 0.5400],
         [0.6036, 0.3873, 0.5681],
         [0.6574, 0.9185, 0.3451]]])

注意:

  1. 在变量input中第一个元素是1,那么在输出中,对应weight中下标为1的向量,就是weight[1],其值是[0.7769, 0.7640, 0.3664],其他的元素以此类推。
  2. weight的行数是可能的最大检索数+1的原因是,input的元素需要访问下标为index的数值。在例子中,最大的所以数为8,故weight 的行数至少为9,当然可以大于9.
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值