import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# Word-embedding demo for sequence modeling (Transformer tutorial, step 1-2).
# Builds toy source/target token-index sequences, right-pads them with 0 to a
# common length, and looks them up in nn.Embedding tables.
batch_size = 2
# Vocabulary sizes (index 0 is reserved as the padding token).
max_num_src_words = 8
max_num_tgt_words = 8
model_dim = 8
# Maximum sequence lengths (kept for later positional-encoding steps).
max_src_seq_len = 5
max_tgt_seq_len = 5
max_position_len = 5

# Step 1: build the sequences; each token is its index in the vocabulary.
# Fixed per-sample lengths for the two batches (could also be sampled with
# torch.randint(2, 5, (batch_size,))).
src_len = torch.tensor([2, 4], dtype=torch.int32)
tgt_len = torch.tensor([4, 3], dtype=torch.int32)


def _make_padded_batch(lengths, vocab_size):
    """Stack random index rows, each right-padded with 0 to the batch max length.

    lengths: 1-D int tensor of per-sample sequence lengths.
    vocab_size: exclusive upper bound for the random indices (indices are 1..vocab_size-1,
        so 0 stays free for padding).
    Returns an int64 tensor of shape (len(lengths), max(lengths)).
    """
    pad_to = int(lengths.max())
    rows = [
        # Convert lengths to plain ints: torch.randint sizes and F.pad amounts
        # expect Python ints, not 0-dim tensors.
        F.pad(torch.randint(1, vocab_size, (n,)), (0, pad_to - n))
        for n in lengths.tolist()
    ]
    return torch.stack(rows)


src_seq = _make_padded_batch(src_len, max_num_src_words)
tgt_seq = _make_padded_batch(tgt_len, max_num_tgt_words)

# Step 2: embedding tables; +1 row so index 0 can act as the padding entry.
src_embedding_table = nn.Embedding(max_num_src_words + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words + 1, model_dim)
src_embedding = src_embedding_table(src_seq)
# 【Transformer】word embedding
# (scraped blog metadata: latest recommended article published 2024-07-25 22:25:44)