import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
    return t if isinstance(t, tuple) else (t, t)
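As a quick illustration (this check is mine, not part of the original snippet), pair turns a bare integer into a square (height, width) tuple and passes tuples through unchanged, which is what lets image_size and patch_size further down accept either form:

pair(256)         # (256, 256)
pair((256, 128))  # (256, 128)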
x = torch.rand(1,256,256,3)
model = FeedForward(dim=3, hidden_dim=9)
y = model(x)
y.shape
torch.Size([1, 256, 256, 3])
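The FeedForward module used above (and the PreNorm wrapper used inside Transformer below) are defined in the earlier part of the walkthrough and are not shown in this excerpt. If you are running this section on its own, a minimal sketch consistent with the standard vit-pytorch definitions would be:

# Minimal sketch of the two helpers assumed by the code below (standard vit-pytorch style).
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # apply LayerNorm before the wrapped function (attention or MLP)
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

Both modules preserve the last dimension dim, which is why the demo above returns a tensor with the same shape as its input.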
Attention
class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        # a single projection produces q, k and v, split along the last dimension
        qkv = self.to_qkv(x).chunk(3, dim=-1)                                    # 3 x (b, n, inner_dim)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)  # each (b, heads, n, dim_head)

        # scaled dot-product attention scores and weights
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale          # (b, heads, n, n)
        attn = self.attend(dots)

        # weighted sum of the values, then merge the heads back together
        out = einsum('b h i j, b h j d -> b h i d', attn, v)                     # (b, heads, n, dim_head)
        out = rearrange(out, 'b h n d -> b n (h d)')                             # (b, n, inner_dim)
        return self.to_out(out)
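Mirroring the FeedForward check above, a quick sanity check (added here, not in the original) shows that Attention also maps a (batch, tokens, dim) tensor back to the same shape; the token count 1025 simply anticipates the ViT run further down:

attn = Attention(dim=128, heads=8, dim_head=64)
x = torch.rand(1, 1025, 128)
attn(x).shape  # torch.Size([1, 1025, 128])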
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        # each block: pre-norm attention and pre-norm MLP, each with a residual connection
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
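The same kind of check (again, my addition) works for a whole Transformer stack, since every block is residual and therefore shape-preserving; it assumes the PreNorm and FeedForward definitions mentioned earlier:

blocks = Transformer(dim=128, depth=2, heads=8, dim_head=64, mlp_dim=128)
blocks(torch.rand(1, 1025, 128)).shape  # torch.Size([1, 1025, 128])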
ViT
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
                 pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        # split the image into patches, flatten each patch and project it to `dim`
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        print('input_size: ', img.shape)
        x = self.to_patch_embedding(img)
        print('after patch_embedding: ', x.shape)
        b, n, _ = x.shape
        print('before cls_token: ', x.shape)

        # prepend a learnable class token to the patch sequence
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        print('after cat: ', x.shape)

        # add learnable position embeddings, then embedding dropout
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        print('before transformer: ', x.shape)

        x = self.transformer(x)
        print('after transformer: ', x.shape)

        # keep only the class token ('cls') or mean-pool over all tokens ('mean')
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        print(x.shape)
        x = self.to_latent(x)
        print(x.shape)
        return self.mlp_head(x)
x = torch.rand(1,3,256,256)
model = ViT(
    image_size=(256, 256),
    patch_size=(8, 8),
    num_classes=8,
    dim=128,
    depth=6,
    heads=8,
    mlp_dim=128,
)
y = model(x)
y.shape
input_size: torch.Size([1, 3, 256, 256])
after patch_embedding: torch.Size([1, 1024, 128])
before cls_token: torch.Size([1, 1024, 128])
after cat: torch.Size([1, 1025, 128])
before transformer: torch.Size([1, 1025, 128])
after transformer: torch.Size([1, 1025, 128])
torch.Size([1, 128])
torch.Size([1, 128])
torch.Size([1, 8])
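The final (1, 8) tensor holds raw logits, one per class. As a small follow-up (not part of the original output), they can be turned into probabilities and a predicted label in the usual way; the concrete values depend on the randomly initialised weights:

probs = y.softmax(dim=-1)    # shape (1, 8), rows sum to 1
pred = probs.argmax(dim=-1)  # predicted class index, shape (1,)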