PyTorch 学习笔记——手写 ViT：image2embedding 与模型原理

参考：B 站 UP 主 deep_thoughts

import torch
import torch.nn as nn
import torch.nn.functional as F


# step1 convert image to embedding  vector sequence 
def image2emb_navie(image, path_size, weight):
    """Cut ``image`` into non-overlapping square patches and project each one.

    Args:
        image: tensor laid out as (batch, channel, height, width).
        path_size: side length of every square patch. It is also used as the
            stride, so neighbouring patches never overlap.
        weight: projection matrix whose first dimension must equal
            channel * path_size * path_size (one flattened patch).

    Returns:
        Tensor of shape (batch, num_patches, weight.shape[-1]).
    """
    # unfold returns (batch, channel * path_size**2, num_patches); swapping the
    # last two axes makes every row a single flattened patch.
    columns = F.unfold(image, kernel_size=path_size, stride=path_size)
    flattened_patches = columns.transpose(-1, -2)
    projected = flattened_patches @ weight
    return projected


def image2emb_conv(image, kernel, stride):
    """Produce patch embeddings with a single strided 2-D convolution.

    Args:
        image: tensor laid out as (batch, in_channels, height, width).
        kernel: conv weight laid out as (out_channels, in_channels, kh, kw).
        stride: convolution stride; setting it equal to the kernel size
            yields non-overlapping patches.

    Returns:
        Tensor of shape (batch, out_h * out_w, out_channels): one row per
        patch position, in row-major order over the output grid.
    """
    feature_map = F.conv2d(image, kernel, stride=stride)
    batch, channels, out_h, out_w = feature_map.shape
    # Collapse the spatial grid into a single patch axis...
    tokens = feature_map.reshape(batch, channels, out_h * out_w)
    # ...then put patches before channels so each row is one embedding.
    return tokens.permute(0, 2, 1)


# test code for  image2emb 
bs, ic, image_h, image_w = 1, 3, 8, 8
path_size = 4
model_dim = 8
max_num_token = 16
num_class = 10
# CrossEntropyLoss needs labels in [0, num_class), so derive the bound from
# num_class instead of hard-coding it.
label = torch.randint(num_class, (bs,))
# Number of values in one flattened patch (patch area times input channels).
path_depth = path_size * path_size * ic
# (path_depth, model_dim): projects one flattened patch to model_dim; the same
# matrix, transposed and reshaped, is used as the conv kernel below.
weight = torch.randn(path_depth, model_dim)
image = torch.randn(bs, ic, image_h, image_w)
patch_embedding_navie = image2emb_navie(image, path_size, weight)

# Reinterpret the projection matrix as a conv kernel laid out as
# (out_channels, in_channels, kh, kw); reshape does not depend on the memory
# layout left behind by transpose.
kernel = weight.transpose(0, 1).reshape(-1, ic, path_size, path_size)
patch_embedding_conv2d = image2emb_conv(image, kernel, path_size)

# Both routes apply the same linear map to the same patches, so they must
# agree up to floating-point error.
assert torch.allclose(patch_embedding_navie, patch_embedding_conv2d, atol=1e-4)

# step 2: prepend a learnable CLS token embedding to every sequence.
cls_token_embedding = torch.randn(model_dim, requires_grad=True)
token_embedding = torch.cat(
    [cls_token_embedding.view(1, 1, model_dim).repeat(bs, 1, 1), patch_embedding_conv2d],
    dim=1,
)  # (bs, 1 + num_patches, model_dim)

# step 3: add learnable position embeddings.
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
if seq_len > max_num_token:
    # Slicing the table would silently return too few rows otherwise.
    raise ValueError(
        f"sequence length {seq_len} exceeds max_num_token={max_num_token}"
    )
position_embedding = torch.tile(
    position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1]
)
# Out-of-place add instead of mutating the torch.cat output.
token_embedding = token_embedding + position_embedding

# step 4: pass the embeddings through a transformer encoder.
# token_embedding is (bs, seq_len, model_dim). batch_first=True makes the
# encoder read dim 0 as the batch; with the default (False) it would treat
# dim 0 as the sequence and attend across the batch instead of the tokens.
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)

# step 5: classify from the CLS token's output (position 0 of each sequence).
cls_token_output = encoder_output[:, 0, :]
linear_layer = nn.Linear(model_dim, num_class)
logits = linear_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(loss)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值