Transformer模型现在已经广泛应用于NLP、CV等各种场景,并且取得了很好的效果。在此记录一下如何使用PyTorch构建Transformer模型进行分类,具体代码如下:
import torch
import numpy as np
import torch.nn as nn
from configs.config import opt
class trans_model(nn.Module):
    """Transformer-encoder classifier.

    Adds sinusoidal positional encodings to the input embeddings, passes
    them through a stack of ``nn.TransformerEncoderLayer``s, and projects
    every position to ``num_labels`` logits.

    Args:
        d_model: embedding dimension of the inputs (must be divisible by
            ``nhead``; should be even for the sin/cos interleaving).
        nhead: number of attention heads per encoder layer.
        num_layers: number of stacked encoder layers.
        num_labels: output width of the final projection (default 2,
            matching the original hard-coded binary head).
    """

    def __init__(self, d_model, nhead, num_layers, num_labels=2):
        super(trans_model, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.linear = nn.Linear(d_model, num_labels)
        self.num_labels = num_labels

    @staticmethod
    def _positional_encoding(seq_len, d_model, dtype, device):
        """Sinusoidal PE of shape (seq_len, d_model), matched to the
        input's dtype/device so the addition never promotes or copies
        across devices."""
        position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
        # Frequencies 10000^(-2k/d_model) shared by dim pairs (2k, 2k+1).
        div_term = torch.pow(
            torch.tensor(10000.0),
            -torch.arange(0, d_model, 2, dtype=torch.float32) / d_model,
        )
        pe = torch.zeros(seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # Slice guards the case of an odd d_model.
        pe[:, 1::2] = torch.cos(position * div_term)[:, : d_model // 2]
        return pe.to(dtype=dtype, device=device)

    def forward(self, inputs):
        """Encode ``inputs`` and return per-position logits.

        Args:
            inputs: embedded sequence tensor (..., seq_len, d_model).
                NOTE: the PE is sized from the input itself rather than
                hard-coded to (128, 512), so any seq_len/d_model works.

        Returns:
            Tensor of shape (..., seq_len, num_labels).
        """
        seq_len, d_model = inputs.size(-2), inputs.size(-1)
        pe = self._positional_encoding(seq_len, d_model, inputs.dtype, inputs.device)
        # Out-of-place add: the caller's tensor must not be mutated.
        encoded = inputs + pe
        trans_out = self.transformer_encoder(encoded)
        return self.linear(trans_out)
def PositionalEncoding(max_seq_len, embed_dim, inputs):
    """Build the standard sinusoidal positional encoding ("Attention Is
    All You Need"):

        PE(pos, 2k)   = sin(pos / 10000^(2k / embed_dim))
        PE(pos, 2k+1) = cos(pos / 10000^(2k / embed_dim))

    Fixes vs. the original: the exponent uses the pair index 2*(i//2)
    instead of 2*i (so each sin/cos pair shares one frequency), the
    result is float32 (float64 would make an in-place add onto float32
    embeddings raise), and the matrix is computed once by broadcasting
    instead of a triple Python loop per batch element.

    Args:
        max_seq_len: number of positions to encode.
        embed_dim: encoding dimensionality (should equal the model's d_model).
        inputs: tensor/array whose ``shape[0]`` gives the batch size;
            only its shape is read.

    Returns:
        float32 tensor of shape (inputs.shape[0], max_seq_len, embed_dim),
        identical across the batch dimension.
    """
    pos = np.arange(max_seq_len, dtype=np.float32)[:, None]   # (L, 1)
    dim = np.arange(embed_dim)[None, :]                       # (1, D)
    # Dims (2k, 2k+1) share frequency 10000^(2k/D) -> pair index via //2.
    angle = pos / np.power(10000.0, 2.0 * (dim // 2) / embed_dim)
    pe = np.where(dim % 2 == 0, np.sin(angle), np.cos(angle))
    # Same matrix for every batch element; materialize the repeat so the
    # returned tensor owns contiguous memory.
    batched = np.broadcast_to(pe, (inputs.shape[0],) + pe.shape).copy()
    return torch.tensor(batched, dtype=torch.float32)