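This paper applies the Transformer architecture to computer vision (mainly image classification) and achieves good results.
The overall idea is quite simple: split the image into many small patches, arrange the patches into a matrix, and feed that matrix into the Transformer encoder module. The detailed computation is shown in the figure below.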
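To make the patch-splitting step more concrete, here is a minimal sketch in PyTorch. The function name `image_to_patches`, the 16x16 patch size, and the 224x224 input are assumptions chosen for illustration; they are not taken from the code shared below, which may do this step differently.

```python
import torch


def image_to_patches(images: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """Split images of shape (B, C, H, W) into flattened patches of shape (B, N, C*P*P).

    Hypothetical helper for illustration only.
    """
    b, c, h, w = images.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image size must be divisible by the patch size"
    # Cut height and width into blocks of size P: (B, C, H/P, P, W/P, P)
    x = images.reshape(b, c, h // p, p, w // p, p)
    # Bring the two patch-grid axes forward: (B, H/P, W/P, C, P, P)
    x = x.permute(0, 2, 4, 1, 3, 5)
    # Flatten the grid into a sequence and each patch into a vector: (B, N, C*P*P)
    return x.reshape(b, (h // p) * (w // p), c * p * p)


if __name__ == "__main__":
    imgs = torch.randn(2, 3, 224, 224)
    patches = image_to_patches(imgs)
    print(patches.shape)  # torch.Size([2, 196, 768])
```

Each row of the resulting matrix is one flattened patch, so a 224x224 RGB image becomes a sequence of 196 vectors of length 768, which is the kind of sequence input a Transformer encoder expects.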
Here I will mainly share the code.
import torch
import torch.nn as nn
import math


class MLP(nn.Module):
    # Feed-forward block: expand dim -> hidden_dim, apply a non-linearity, project back to dim.
    def __init__(self, dim, hidden_dim, dropout=0.):
        super(MLP, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, input):
        # The output has the same shape as the input: (..., dim).
        output = self.net(input)
        return output

class MSA(nn.Module):
"