torch.nn.Embedding

详见 官方文档

class torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, device=None, dtype=None)

参数解读

  • num_embeddings (int) – 词嵌入矩阵的大小,即 有多少个 词嵌入向量 对应行
  • embedding_dim (int) – 每个词嵌入向量的维度 对应列
  • padding_idx (int, optional) – 可选参数。如果 指定该参数,则 padding_idx 将不会再 training 中更新,即将保持一个 固定向量不变

torch.nn.Embedding 作用

  • 是一个简单的查找表,用于 存储固定字典 和 大小 的嵌入
  • 该模块通常用于存储词嵌入使用索引检索它们

~Embedding.weight (Tensor) – 维度为 ( n u m _ e m b e d d i n g s , e m b e d d i n g _ d i m ) (num\_embeddings, embedding\_dim) (num_embeddings,embedding_dim) 的可学习权重矩阵,初始化为 N ( 0 , 1 ) N(0,1) N(0,1)

Embedding 输入、输出

  • 该模块的输入是索引列表,输出是相应的词嵌入
  • Input: ( ∗ ) (∗) (), 待词嵌入的 单词序列(IntTensor or LongTensor 类型)
  • Output: ( ∗ , H ) (∗,H) (,H),其中 * 是 input shape, H = e m b e d d i n g _ d i m H=embedding\_dim H=embedding_dim

例子

>>> # 嵌入矩阵大小 (10, 3),即 10 个向量,每个向量维度为 (3, 1)
>>> embedding = nn.Embedding(10, 3)
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
>>> embedding(input)
tensor([[[-0.0251, -1.6902,  0.7172],
         [-0.6431,  0.0748,  0.6969],
         [ 1.4970,  1.3448, -0.9685],
         [-0.3677, -2.7265, -0.1685]],

        [[ 1.4970,  1.3448, -0.9685],
         [ 0.4362, -0.4004,  0.9400],
         [-0.6431,  0.0748,  0.6969],
         [ 0.9124, -2.3616,  1.1151]]])

>>> # 使用 padding_idx (这里用 0 填充 <pad> )
>>> embedding = nn.Embedding(10, 3, padding_idx=0)
>>> input = torch.LongTensor([[0,2,0,5]])
>>> embedding(input)
tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.1535, -2.0309,  0.9315],
         [ 0.0000,  0.0000,  0.0000],
         [-0.1655,  0.9897,  0.0635]]])

# 测试
>>> embedding(torch.LongTensor([0]))
tensor([[0., 0., 0.]], grad_fn=<EmbeddingBackward>)
>>> embedding(torch.LongTensor([2]))
tensor([[ 0.1535, -2.0309,  0.9315]], grad_fn=<EmbeddingBackward>)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值