【Transformer】word embedding

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# 关于word embedding ,以序列建模为例
batch_size = 2
# 单词表大小
max_num_src_words = 8
max_num_tgt_words = 8

model_dim = 8

# 序列的最大长度
max_src_seq_len = 5
max_tgt_seq_len = 5
max_position_len = 5
# src_len = torch.randint(2, 5, (batch_size,))
# tgt_len = torch.randint(2, 5, (batch_size,))

# 考虑source sentence和target sentence
# step:1构建序列,序列的字符以其在词表中的索引的形式表示
src_len = torch.Tensor([2, 4]).to(torch.int32)
tgt_len = torch.Tensor([4, 3]).to(torch.int32)
# 单词索引构成的句子:pad:默认值为0-->unsqueeze:将一维向量再增加一维->在第0维cat
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L,)), (0, max(src_len) - L)), 0) \
                     for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L,)), (0, max(tgt_len) - L)), 0) \
                     for L in tgt_len])
# print(src_len)
# print(tgt_len)

# print("单词索引构成的源序列:\n", src_seq)
# print("单词索引构成的目标序列:\n", tgt_seq)

# step2:构造embedding

src_embedding_table = nn.Embedding(max_num_src_words + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words + 1, model_dim)

src_embedding = src_embedding_table(src_seq)

# print(src_embedding_table.weight)
# print(src_seq)
# print(src_embedding)
  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值