import os
# Work around the "OMP: Error #15" duplicate-OpenMP-runtime crash that some
# conda/MKL setups hit; set it before importing torch so the flag is already
# visible when the OpenMP runtime loads.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import torch
import torch.nn as nn
class TransModel(nn.Module):
    def __init__(self, hidden_dim, num_head, num_layers):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_head,
            batch_first=True  # inputs are [batch_size, max_seq_len, hidden_dim]
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True  # bidirectional, so the output feature size doubles
        )
    def forward(self, inputs):
        '''
        inputs: [batch_size, max_seq_len, hidden_dim]
        '''
        # [batch_size, max_seq_len, hidden_dim]
        encoder_output = self.encoder(inputs)
        # bidirectional LSTM -> [batch_size, max_seq_len, hidden_dim * 2]
        lstm_output, _ = self.lstm(encoder_output)
        return lstm_output
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_dim = 128
model = TransModel(hidden_dim=hidden_dim, num_head=4, num_layers=4).to(device)
# Generate a random tensor with batch=5, seq_len=20, hidden_dim=128.
# For your own task, you first need to turn your raw_text into an inputs
# tensor like this one. I'm not sure how your preprocessing looks; in any
# case, as long as the inputs' hidden_dim matches the model's hidden_dim,
# the code will not error (see the sketch after this script).
inputs = torch.randn(5, 20, hidden_dim).to(device)
output = model(inputs)
print(output.shape)  # torch.Size([5, 20, 256]) -- hidden_dim * 2 from the bidirectional LSTM
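
If you are starting from raw text rather than random numbers, one common recipe is: tokenize each sentence into integer ids, pad every sequence to the same length, then look the ids up in an nn.Embedding whose embedding size equals hidden_dim. The sketch below continues the script above; the tiny corpus, the whitespace tokenizer, and the vocabulary are all toy assumptions standing in for your real preprocessing.

# -- Minimal sketch of raw_text -> inputs (toy vocab/tokenizer, for illustration only) --
corpus = ["the cat sat", "dogs bark loudly at night"]  # hypothetical raw_text
vocab = {"<pad>": 0}
for sent in corpus:
    for tok in sent.split():
        vocab.setdefault(tok, len(vocab))

max_seq_len = 20
pad_id = vocab["<pad>"]
batch_ids = []
for sent in corpus:
    ids = [vocab[t] for t in sent.split()][:max_seq_len]
    ids += [pad_id] * (max_seq_len - len(ids))  # pad every sequence to max_seq_len
    batch_ids.append(ids)
token_ids = torch.tensor(batch_ids, device=device)  # [batch_size, max_seq_len]

embedding = nn.Embedding(len(vocab), hidden_dim).to(device)
inputs = embedding(token_ids)  # [batch_size, max_seq_len, hidden_dim]
output = model(inputs)         # [batch_size, max_seq_len, hidden_dim * 2]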
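
Because the LSTM is bidirectional, every token comes out with hidden_dim * 2 features, so anything downstream has to match that width. Below is a hypothetical per-token tagging head attached to the output computed above; num_classes is made up for illustration.

num_classes = 10  # hypothetical number of labels
head = nn.Linear(hidden_dim * 2, num_classes).to(device)
logits = head(output)  # [batch_size, max_seq_len, num_classes]
print(logits.shape)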