# Reference: based on the tutorial by Bilibili uploader deep_thoughts
import torch
import torch.nn as nn
import torch.nn.functional as F
# step1 convert image to embedding vector sequence
def image2emb_navie(image, path_size, weight):
    """Embed non-overlapping image patches with an explicit matrix product.

    Args:
        image: input passed to ``F.unfold``; the caller in this file uses a
            (batch, channel, height, width) tensor.
        path_size: side length of each square patch. It is used as both the
            kernel size and the stride, so neighbouring patches do not overlap.
        weight: projection matrix that each flattened patch is multiplied by.

    Returns:
        The patch embeddings: unfolded patches, with their last two
        dimensions swapped, multiplied by ``weight``.
    """
    # Cut the image into patches; stride == kernel size means no overlap.
    unfolded = F.unfold(image, kernel_size=path_size, stride=path_size)
    # Swap the last two dims so each row holds one flattened patch.
    flat_patches = unfolded.transpose(-1, -2)
    return torch.matmul(flat_patches, weight)
def image2emb_conv(image, kernel, stride):
    """Embed image patches with a strided 2-D convolution.

    Args:
        image: input passed to ``F.conv2d``.
        kernel: convolution weight passed to ``F.conv2d``.
        stride: convolution stride.

    Returns:
        Tensor of shape (batch, out_h * out_w, out_channels): the convolution
        output with its spatial grid flattened into a sequence dimension.
    """
    # Convolution output is (batch, out_channels, out_h, out_w).
    feature_map = F.conv2d(image, kernel, stride=stride)
    batch, channels, rows, cols = feature_map.shape
    # Flatten the spatial grid, then move channels to the last dimension.
    flattened = feature_map.reshape(batch, channels, rows * cols)
    return flattened.permute(0, 2, 1)
# Demo script: build patch embeddings two ways, then run a single forward
# pass through a small Transformer encoder with a classification head.
bs, ic, image_h, image_w = 1, 3, 8, 8
path_size = 4
model_dim = 8
# Number of rows in the position-embedding table. It must be at least
# (number of patches + 1), because the table is sliced to the sequence length.
max_num_token = 16
num_class = 10
# Random integer class targets in [0, num_class), one per sample.
label = torch.randint(num_class, (bs,))

# Number of scalar values contained in one patch.
path_depth = path_size * path_size * ic
# Projection matrix of shape (path_depth, model_dim); model_dim plays the role
# of the conv output-channel count, path_depth of kh * kw * input_channels.
weight = torch.randn(path_depth, model_dim)
image = torch.randn(bs, ic, image_h, image_w)

patch_embedding_navie = image2emb_navie(image, path_size, weight)

# Re-lay the same weights out as a conv kernel: oc * ic * kh * kw.
# reshape (rather than view) does not depend on the memory layout left
# behind by transpose.
kernel = weight.transpose(0, 1).reshape(-1, ic, path_size, path_size)
patch_embedding_conv2d = image2emb_conv(image, kernel, path_size)

# Step 2: prepend a CLS token embedding.
# requires_grad=True marks the tensor as trainable.
cls_token_embedding = torch.randn(model_dim, requires_grad=True)
# (model_dim,) -> (1, 1, model_dim) -> (bs, 1, model_dim), then concatenate it
# in front of the patch tokens along the sequence dimension (dim=1).
token_embedding = torch.cat(
    [
        cls_token_embedding.view(1, 1, model_dim).repeat(bs, 1, 1),
        patch_embedding_conv2d,
    ],
    dim=1,
)

# Step 3: add position embeddings.
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
# Take the first seq_len rows and repeat them once per sample in the batch.
position_embedding = torch.tile(
    position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1]
)
# Out-of-place add: avoids mutating a tensor that is part of the autograd graph.
token_embedding = token_embedding + position_embedding

# Step 4: pass the embeddings through a Transformer encoder.
# token_embedding is laid out as (batch, seq, dim), so batch_first=True is
# required; with the default (batch_first=False) the encoder would treat the
# batch axis as the sequence axis and tokens of one image would never attend
# to each other.
encoder_layer = nn.TransformerEncoderLayer(
    d_model=model_dim, nhead=8, batch_first=True
)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)

# Step 5: classify from the CLS token (sequence position 0).
cls_token_output = encoder_output[:, 0, :]
liner_layer = nn.Linear(model_dim, num_class)
logits = liner_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(loss)