torch.nn.Embedding函数

torch.nn.Embedding可将数据的离散表示转换成连续的表示。以NLP为例,若输入数据为中文句子,则每个汉字可表示成one-hot向量,但此方法若输入数据涉及的汉字数量过大,则输入维度可达几十万维,故可利用Embedding将输入数据进行嵌入表示。可将Embedding看作一个查询表(二维),一个汉字对应查询表的一行数据,则行数就代表所涉及汉字的总个数,列数则代表嵌入的维度(自己定)。
torch.nn.Embedding(num_embeddings, embedding_dim)的参数主要有num_embeddings和embedding_dim两个,其中num_embeddings代表查询表的大小(行数),embedding_dim代表嵌入维度。
import torch

embedding = torch.nn.Embedding(num_embeddings=4, embedding_dim=2) # 创建4x2的查询表
print(embedding.weight)
test = torch.randint(0, 4, (2, 2))   # 生成数据,维度为2x2
print(test)
test = embedding(test)               # 将数据进行嵌入,输出维度为2x2x2
print(test)

"""
Parameter containing:
tensor([[ 0.1208, -1.6480],
        [ 0.3138, -0.9083],
        [-1.0675,  0.0510],
        [-0.4471,  1.4563]], requires_grad=True)

tensor([[1, 3],
        [2, 2]])

tensor([[[ 0.3138, -0.9083],
         [-0.4471,  1.4563]],

        [[-1.0675,  0.0510],
         [-1.0675,  0.0510]]], grad_fn=<EmbeddingBackward0>)
"""

如图所示,经过Embedding后,生成的数据test中的1,3,2,2被查询表中行数索引为1,3,2,2的向量代替了。

注意:要进行嵌入的数据为整数(表示位置),且不能超过查询表的行数

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

吃冰442

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值