torch.nn.Embedding
在我看来torch.nn.Embedding就是把离散值映射成连续值,事实上它也是这么操作的,其流程是离散值 → \rightarrow →Onehot → \rightarrow →连续值。
背景
除了在自然语言处理中,我们日常的时序预测中也会用到torch.nn.Embedding, torch.nn.Embedding的输入是离散值,其适用范围包括但不限于年、月、日、时。
以月为例,月的取值范围就是[0,2,3,4,5,6,7,8,9,10,11,11],如果用Onehot编码那每一月就是一个12维的Onehot向量,假设dmodel=64, 则需要定义一个编码层
import torch
import torch.nn.functional as F
embedding = torch.nn.Embedding(12,64)
input = torch.LongTensor([[1,2,3],[11,10,9]])
input_onehot = F.one_hot(input, num_classes=12)
out = embedding(input)
print(input.size())
print(input_onehot.size())
print(out.size())
输出结果是:
torch.Size([2, 3])
torch.Size([2, 3, 12])
torch.Size([2, 3, 64])
其实torch.nn.Embedding就是先将input进行Onehot编码,然后在做一个线性映射(你可以中这么理解,但不要这么实操,因为里面有一些dtype的问题),将离散值编码到一个预设的dmoel维的实数值向量空间。
更一般的,embedding = torch.nn.Embedding(V,D),其中V=vocabulary account(离散值集合的势), D=dmoel.
对于一个input(B,T), 其中B=batch_size, T=sequence_length。
对应的input_onehot (B,T,V)
Embedding之后的output(B,T,D)
References
- https://www.bilibili.com/video/BV1wm4y187Cr/?spm_id_from=333.337.search-card.all.click&vd_source=107a981693181bff01f387d4a6c314c9