PyTorch 21. PyTorch中nn.Embedding模块

PyTorch 21. PyTorch中nn.Embedding模块

torch.nn.Embedding

函数

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)

参数解释

  1. num_embeddings: 查询表的大小
  2. embedding_dim: 每个查询向量的维度

函数大概解释:一个保存了固定字典和大小的简单查找表。这个模块常用来保存词嵌入和用下标检索他们。模块的输入是一个下标的列表,输出是对应的词嵌入。相当于随机生成了一个tensor,可以把它看作一个查询表,其size为[num_embeddings, embedding_dim]。其中num_embeddings是查询表的大小,embedding_dim是每个查询向量的维度。

通俗来讲,这是一个矩阵类,里面初始化了一个随机矩阵,矩阵的长时字典的大小,宽是用来表示字典中每个元素的属性向量,向量的维度根据你想要表示的元素的复杂度而定。类实例化之后可以根据字典中元素的下标来查找元素对应的向量。输入下标0,输出就是embeds矩阵中第0行。

具体实例

  1. 创建查询矩阵并使用它做Embedding:
embedding = nn.Embedding(5,3) #定义一个具有5个单词,维度为3的查询矩阵
print(embedding.weight) #展示该矩阵的具体内容
test = torch.LongTensor([[0, 2, 0, 1],
						[1, 3, 4, 4]]) #该test矩阵用于被embed,其size为[2,4]
# 其中的第一行为[0, 2, 0, 1], 表示获取查询矩阵中ID为0, 2, 0, 1的查询向量
# 可以在之后的test输出中与embed的输出进行比较
test = embedding(test)
print(test.size()) #输出embed后test的size, 为[2, 4, 3],增加的3,是因为查询向量的维度为3
print(test) # 输出embed后的test的内容

输出:
Parameter containing:
tensor([[-1.8056,  0.1836, -1.4376],
        [ 0.8409,  0.1034, -1.3735],
        [-1.3317,  0.8350, -1.7235],
        [ 1.5251, -0.2363, -0.6729],
        [ 0.4148, -0.0923,  0.2459]], requires_grad=True)
torch.Size([2, 4, 3])
tensor([[[-1.8056,  0.1836, -1.4376],
         [-1.3317,  0.8350, -1.7235],
         [-1.8056,  0.1836, -1.4376],
         [ 0.8409,  0.1034, -1.3735]],

        [[ 0.8409,  0.1034, -1.3735],
         [ 1.5251, -0.2363, -0.6729],
         [ 0.4148, -0.0923,  0.2459],
         [ 0.4148, -0.0923,  0.2459]]], grad_fn=<EmbeddingBackward>)			

可以看出,embedding相当于用一个矩阵的索引去查询原始嵌入矩阵中元素的特征。

  1. 寻找查询矩阵中特定ID(词)的查询向量(词向量):
# 访问某个ID,即第N个词的查询向量
print(embedding(torch.LongTensor([3]))) # 这里表示查询第3个词的词向量

输出:
tensor([[-1.6016, -0.8350, -0.7878]], grad_fn=<EmbeddingBackward>)
  1. 输出的hello这个词的word embedding
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

word_to_ix = {'hello':0, 'world':1}
embeds = nn.Embedding(2,5)
hello_idx = torch.LongTensor([word_to_ix['hello']])
hello_idx = Variable(hello_idx)
hello_embed = embeds(hello_idx)
print(hello_embed)

输出结果:

Variable containing:
 0.4606  0.6847 -1.9592  0.9434  0.2316
[torch.FloatTensor of size 1x5]
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值