看官方的文档之后,自己的理解,供以后学习之用。
#输入的batch为2,每个batch有4个索引
input = torch.tensor([[1,2,4,5],[4,3,2,9]])
#字典中包含的词有10个,每个3维
embedding_matrix = torch.rand(10, 3)
F.embedding(input, embedding_matrix)
tensor([[[ 0.8490, 0.9625, 0.6753],
[ 0.9666, 0.7761, 0.6108],
[ 0.6246, 0.9751, 0.3618],
[ 0.4161, 0.2419, 0.7383]],
[[ 0.6246, 0.9751, 0.3618],
[ 0.0237, 0.7794, 0.0528],
[ 0.9666, 0.7761, 0.6108],
[ 0.3385, 0.8612, 0.1867]]])
所以最后得到了一个shape为(2,4,3)的一个tensor。