一、VIT简介
ViT是2020年Google团队提出的将Transformer应用在图像分类的模型 但是因为其模型“简单”且效果好,可扩展性强。
二、运行代码前服务器准备
代码需要在服务器上运行,因此我们需要准备一个服务器,并在服务器上配置好conda的虚拟环境和pytorch上安装好spikingjelly的包。
我们需要有一个服务器的账号,这里使用的xshell7(也可使用window自带的远程桌面进行服务器连接),本文以xshell7进行示范:
添加服务器连接:
在主机和端口分别输入其相应的服务器地址与端口号,然后点击“用户身份验证”,在其中输入分配的服务器个人账号。
然后连接即可。
注意:连接服务器可能需要使用到相应的vpn。
成功连接上服务器。
接下来需要在服务器上配置虚拟环境
注意:为了服务器安全,一般服务器关闭网络功能,所以需要我们使用scp将文件下载好后上传到服务器,这里我所使用的是与xshell配套的xftp进行文件传输。
详细过程可参考:
Ubuntu20.04安装pytorch(包括安装Anaconda和虚拟环境配置以及安装包spikingjelly)_火锅店的保安长的博客-CSDN博客_ubuntu安装pytorch
在配置好虚拟环境之后,需要再安装一个eniop包(操纵张量-通过灵活而强大的张量操作符提供易读并可靠的代码),即可运行VIT代码。
三、代码运行
在服务器上运行Vit代码:
四、代码详解
![](https://i-blog.csdnimg.cn/blog_migrate/a31f7dc89f2b6520ca4c01c9d3645b0c.png)
1.Add 残差块(类似于resNet结构,将forward(x)+x作为非线性输入)
#残差模块,放在每个前馈网络和注意力之后
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) + x
作用:可以缓解梯度弥散和梯度破坏,解决网络退化等问题。
2.Normalize 归一化
#layernorm归一化,放在多头注意力层和激活函数层。用绝对位置编码的BERT,layernorm用来自身通道归一化
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim) # 只把最后一个纬度归一化 输入向量为 Batch * length * embedding
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
#放置多头注意力后,因为在于多头注意力使用的矩阵乘法为线性变换,后面跟上由全连接网络构成的FeedForward增加非线性结构
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim), # 先映射到高维
nn.GELU(), # 再非线性分类
nn.Linear(hidden_dim, dim) # 再映射回来
)
def forward(self, x):
return self.net(x)
进行线性变换,后对其进行非线性分类,得到其他更深层次的特征后,再映射回低维。
#多头注意力层,多个自注意力连起来。使用qkv计算
class Attention(nn.Module):
def __init__(self, dim, heads=8):
super().__init__()
self.heads = heads
self.scale = dim ** -0.5 # SoftMax的归一化尺度
self.to_qkv = nn.Linear(dim, dim * 3, bias=False) # qkv矩阵一起计算
self.to_out = nn.Linear(dim, dim) #
def forward(self, x, mask = None):
b, n, _, h = *x.shape, self.heads # 第一个表示将[B,N,E]分解成参数 (_ 其实等于dim,不改变输入向量长度)
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h) # rearrange 原来linear后是b * # n * (3 * dim * h)打散并且变成3 * batch * heads *// n * d
dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale # 计算概率矩阵,前面是矩阵点乘,保留bh地维度id*dj = ij ,计算出来是b * h 个n * n的概率矩阵
if mask is not None:
mask = F.pad(mask.flatten(1), (1, 0), value = True)
assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
mask = mask[:, None, :] * mask[:, :, None]
dots.masked_fill_(~mask, float('-inf'))
del mask
attn = dots.softmax(dim=-1) # 最后一维SoftMax
out = torch.einsum('bhij,bhjd->bhid', attn, v) # b h n dim
out = rearrange(out, 'b h n d -> b n (h d)') # concat heads
out = self.to_out(out) # 这个linear层是不是有点问题
return out #
5.构建原始transformer
#将图像切割成一个个图像块,组成序列化的数据输入Transformer执行图像分类任务。
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels=3):
super().__init__()
assert image_size % patch_size == 0 ,' image dimensions must be divisible by the patch size' # 警告内容
num_patches = (image_size // patch_size) ** 2 # 4*4 相当于一段话分成16个字,一个字用49长度编码
patch_dim = channels * patch_size ** 2 # 1 * 7^2 相当于多少个(句子通道)并行
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) # 位置编码,竟然是随机数编码位置!第一个参数不是batch,图片先一张一张编码?可训练的位置编码?
self.patch_to_embedding = nn.Linear(patch_dim, dim) # 49->64 竟然用线性来embedding
self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) # 随机数分类标签编码
self.transformer = Transformer(dim, depth, heads, mlp_dim) # 64 6 8 128
self.to_cls_token = nn.Identity()
self.mlp_head = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Linear(mlp_dim, num_classes)
) # 最终分类
def forward(self, img, mask=None):
p = self.patch_size #
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
cls_tokens = self.cls_token.expand(img.shape[0], -1, -1) # 扩充到批次图片了
x = torch.cat((cls_tokens, x), dim=1) # 在第一维度合并
x += self.pos_embedding # 加上位置编码
x = self.transformer(x, mask)
x = self.to_cls_token(x[:, 0]) # 所有批次的第二维第一个为class token
return self.mlp_head(x)
五、运行结果
准确率达到了98.57%。