桃叶儿尖上尖,柳絮儿飞满了天…
1 导入库
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
解释:其中 einops 库用于张量操作,可以显著增强代码的可读性,使用也比较方便,具体用法可参考 einops 的官方教程。
2 调用
if __name__ == "__main__":
    # Build a ViT classifier for 256x256 RGB inputs cut into 32x32 patches,
    # then run one dummy forward pass to sanity-check the model.
    model = ViT(
        image_size=256,    # input image side length (square images)
        patch_size=32,     # side length of each square patch
        num_classes=1000,  # size of the classification head output
        dim=1024,          # embedding width: input/output size of every block
        depth=6,           # number of transformer blocks stacked
        heads=16,          # number of attention heads per block
        mlp_dim=2048,      # hidden width of the MLP (expand then project back)
        dropout=0.1,
        emb_dropout=0.1,
    )
    batch = torch.rand((2, 3, 256, 256))  # dummy batch: 2 RGB images, 256x256
    logits = model(batch)
下面按照从主干到分支的顺序逐步解释代码。
3 ViT网络
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads,
mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (image_size // patch_size) ** 2
patch_dim = channels * patch_size ** 2
assert