torch.nn.Embedding
在使用pytorch进行词嵌入使用torch.nn.Embedding()就可以做到
nn.Embedding在pytoch中的解释
class torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
norm_type=2, scale_grad_by_freq=False, sparse=False)
模块的输入是一个下标的列表,输出是对应的词嵌入。
参数:
- num_embeddings (int) - 嵌入字典的大小
- embedding_dim (int) - 每个嵌入向量的大小
- padding_idx (int, optional) - 如果提供的话,输出遇到此下标时用零填充
- max_norm (float, optional) - 如果提供的话,会重新归一化词嵌入,使它们的范数小于提供的值
- norm_type (float, optional) - 对于max_norm选项计算p范数时的p
- scale_grad_by_freq (boolean, optional) - 如果提供的话,会根据字典中单词频率缩放梯度
变量:
- weight (Tensor) -形状为(num_embeddings, embedding_dim)的模块中可学习的权值
形状:
- 输入: LongTensor (N, W), N = mini-batch, W = 每个mini-batch中提取的下标数
- 输出: (N, W, embedding_dim)
Embedding的使用
参数的初始化
import torch
import torch.nn as nn
vocab={'a':0,'b':1,'c':2}
#nn.Embedding(1000,50)表示有1000个词,嵌入到50维中
#对比若是one-hot,1000个词将会是1000维
embedd=nn.Embedding(3,5)
Embedd
以简单一维的向量为例
#输入:LongTensor (N, W), N = mini-batch, W = 每个mini-batch中提取的下标数
#输出: (N, W, embedding_dim)
a_idx=torch.LongTensor(vocab['a'])#输入类型为LongTensor a_idx=0
#a_idx在embedding处理后
a=embedd(a_idx)
print(a)
#输出:
#tensor([[ 0.9256, 1.1268, -0.5181, -0.2635, 1.9282]],grad_fn=<EmbeddingBackward>)#
我们输入二维的
i
n
p
u
t
input
input
输入:
L
o
n
g
T
e
n
s
o
r
(
n
,
w
)
LongTensor (n, w)
LongTensor(n,w),
n
n
n为样本的数量,
w
w
w为每个样本中词数
#两组,每组有三个词0 , 1 ,2
inputs=torch.LongTensor([[0,2,1,1],[1,2,1,0]])
outputs=embedd(inputs)
print(outputs)
输出: ( n , w , e m b e d d i n g − d i m ) (n, w, embedding-dim) (n,w,embedding−dim)
tensor([[[ 0.9256, 1.1268, -0.5181, -0.2635, 1.9282],
[-0.1351, 0.7759, 1.4697, 1.2225, 0.0390],
[ 1.2778, -1.8050, -0.7962, -1.5764, 0.0664],
[ 1.2778, -1.8050, -0.7962, -1.5764, 0.0664]],
[[ 1.2778, -1.8050, -0.7962, -1.5764, 0.0664],
[-0.1351, 0.7759, 1.4697, 1.2225, 0.0390],
[ 1.2778, -1.8050, -0.7962, -1.5764, 0.0664],
[ 0.9256, 1.1268, -0.5181, -0.2635, 1.9282]]],
grad_fn=<EmbeddingBackward>)
因为初始化时* embedd=nn.Embedding(3,5)*
所以是 0, 1, 2三个词嵌入
当超过时就会报错
当inputs中出现 3时
inputs=torch.LongTensor([[0,2,1,3],[1,2,1,3]])
outputs=embedd(inputs)
print(outputs)
#RuntimeError: index out of range: Tried to access index 3 out of table with 2 rows.#