torch.nn.Embedding

torch.nn.Embedding

在我看来torch.nn.Embedding就是把离散值映射成连续值,事实上它也是这么操作的,其流程是离散值 → \rightarrow Onehot → \rightarrow 连续值。

背景

除了在自然语言处理中,我们日常的时序预测中也会用到torch.nn.Embedding, torch.nn.Embedding的输入是离散值,其适用范围包括但不限于年、月、日、时。

以月为例,月的取值范围就是[0,2,3,4,5,6,7,8,9,10,11,11],如果用Onehot编码那每一月就是一个12维的Onehot向量,假设dmodel=64, 则需要定义一个编码层

import torch
import torch.nn.functional as F 


embedding = torch.nn.Embedding(12,64)

input = torch.LongTensor([[1,2,3],[11,10,9]])
input_onehot = F.one_hot(input, num_classes=12)

out = embedding(input)

print(input.size())
print(input_onehot.size())
print(out.size())

输出结果是:
torch.Size([2, 3])
torch.Size([2, 3, 12])
torch.Size([2, 3, 64])

其实torch.nn.Embedding就是先将input进行Onehot编码,然后在做一个线性映射(你可以中这么理解,但不要这么实操,因为里面有一些dtype的问题),将离散值编码到一个预设的dmoel维的实数值向量空间。

更一般的,embedding = torch.nn.Embedding(V,D),其中V=vocabulary account(离散值集合的势), D=dmoel.

对于一个input(B,T), 其中B=batch_size, T=sequence_length。

对应的input_onehot (B,T,V)

Embedding之后的output(B,T,D)

References

  1. https://www.bilibili.com/video/BV1wm4y187Cr/?spm_id_from=333.337.search-card.all.click&vd_source=107a981693181bff01f387d4a6c314c9
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值