可见知乎
之前由于懒且不重视没思考embedding,故近期实战对其有些“误解”。害,可见这都是之前欠下的债啊,得补上!小白记录贴,仅供参考~
从pytorch源码里简单找了找,并没有找到对embedding有直观解释的代码,故转向tensorflow。
不管什么框架,原理得是一样的吧~对embedding追根溯源,发现主要包括两部分:
- 对input[batch_size, seq_len]进行one-hot编码[batch_size, vocab_size];
- 将one-hot编码后的矩阵和weight[vacab_size, embed_dim矩阵相乘;
复现代码如下:
利用pytorch给的接口可以得到embedding之后的值如下:
import torch.nn.functional as F
import torch
input = torch.tensor([[1,2,4,5]])
weights = torch.rand(10, 3)
F.embedding(input, weights)
'''
tensor([[[0.2776, 0.0587, 0.9897],
[0.9066, 0.3682, 0.0840],
[0.0370, 0.3854, 0.0091],
[0.5261, 0.5255, 0.1317]]])
'''
按照理解改写如下,得到相同的结果:
import numpy as np
np.matmul(tf.one_hot(input,depth=10),weights)
'''
tensor([[[0.9566, 0.8623, 0.8421],
[0.7956, 0.9499, 0.0336],
[0.4343, 0.6607, 0.8412],
[0.2082, 0.7314, 0.6296]]])
'''
2014年的论文text-CNN 有三种不同的embedding机制:rand/static/non static/,其中static利用训练好的word2vec向量,而non-static应该就是将embedding作为网络的一部分进行训练,这里的微调其实并不是input有所改变,而是embedding层之后的x_emb有改变!归根到底是fine-tune网络~
参考文献:
申小明77:tensorflow中的Embedding操作详解zhuanlan.zhihu.com
【python】np.dot()、np.multiply()、np.matmul()方法以及*和@运算符的用法总结_敲代码的quant的博客-CSDN博客blog.csdn.net