# 8.16笔记:ViT(Vision Transformer)

import torch
import torch.nn as nn
import torch.nn.functional as F

def image2emb_naive(image, patch_size, weight):
    """Embed an image by cutting it into patches and projecting each one.

    Args:
        image: Input tensor; the accompanying note gives its layout as
            (batch, channel, height, width).
        patch_size: Side length of the square patches. It is passed to
            ``F.unfold`` as both kernel size and stride, so the patches
            do not overlap.
        weight: Projection matrix that right-multiplies the flattened
            patches; its first dimension must equal
            ``channel * patch_size * patch_size``.

    Returns:
        The patch embeddings, one row per patch. For a batched input the
        shape is (batch, num_patches, weight.shape[-1]).
    """
    # unfold lays the patches out as columns:
    # (batch, channel * patch_size**2, num_patches).
    columns = F.unfold(image, kernel_size=patch_size, stride=patch_size)
    # Swap the last two axes so each row holds one flattened patch.
    flat_patches = columns.transpose(-2, -1)
    return torch.matmul(flat_patches, weight)

def image2emb_conv(image, kernel, stride):
    """Embed an image with a strided 2-D convolution.

    Args:
        image: Batched input tensor accepted by ``F.conv2d``
            (batch, in_channels, height, width).
        kernel: Convolution filters of shape
            (out_channels, in_channels, kernel_h, kernel_w).
        stride: Stride forwarded to ``F.conv2d``.

    Returns:
        Tensor of shape (batch, out_h * out_w, out_channels): one
        embedding row per spatial output position.
    """
    # Convolution output is (batch, out_channels, out_h, out_w).
    feature_map = F.conv2d(image, kernel, stride=stride)
    batch, channels, out_h, out_w = feature_map.shape
    # Collapse the spatial grid into a single sequence axis.
    tokens = feature_map.reshape(batch, channels, out_h * out_w)
    # Put the sequence axis before the channel axis.
    return tokens.transpose(-2, -1)

# test code for image2emb: build the patch embedding two ways and compare.
bs, ic, image_h, image_w = 1, 3, 8, 8
patch_size = 4
model_dim = 8
max_num_token = 16
num_classes = 10
# Draw labels from [0, num_classes) so they stay valid if num_classes changes.
label = torch.randint(num_classes, (bs,))

patch_depth = patch_size * patch_size * ic
image = torch.randn(bs, ic, image_h, image_w)
# model_dim is the number of output channels; patch_depth is the kernel area
# times the number of input channels.
weight = torch.randn(patch_depth, model_dim)

# Embedding obtained by explicit patch extraction.
patch_embedding_naive = image2emb_naive(image, patch_size, weight)
# Reinterpret each weight column as one conv filter: oc * ic * kh * kw.
kernel = weight.transpose(0, 1).reshape(-1, ic, patch_size, patch_size)

# Embedding obtained by 2-D convolution.
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)

print(patch_embedding_conv)
print(patch_embedding_naive)

# The two methods should agree up to float32 rounding; check it
# programmatically instead of relying on comparing the printed tensors by eye.
if not torch.allclose(patch_embedding_conv, patch_embedding_naive, rtol=1e-4, atol=1e-4):
    raise AssertionError("conv and unfold patch embeddings disagree")

## step2: prepend a classification (CLS) token to the patch sequence
cls_token_embedding = torch.randn(bs, 1, model_dim, requires_grad=True)
token_embedding = torch.cat([cls_token_embedding, patch_embedding_conv], dim=1)

## step3: add position embeddings
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad=True)
seq_len = token_embedding.shape[1]
if seq_len > max_num_token:
    # Slicing would silently return fewer rows than tokens; fail with a clear message.
    raise ValueError(
        f"sequence length {seq_len} exceeds max_num_token {max_num_token}"
    )
# Repeat the first seq_len table rows once per batch item: (bs, seq_len, model_dim).
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0], 1, 1])
token_embedding += position_embedding

## step4: run the token sequence through the Transformer encoder
# token_embedding is laid out (batch, seq, dim). The encoder layer defaults to
# batch_first=False, i.e. (seq, batch, dim), which would treat every token as
# its own length-1 sequence when bs == 1 (and mix different images when
# bs > 1), so the CLS token would never attend to the patches.
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim, nhead=8, batch_first=True)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
encoder_output = transformer_encoder(token_embedding)

## step5: classify from the CLS token output
cls_token_output = encoder_output[:, 0, :]  # sequence position 0 is the CLS token
linear_layer = nn.Linear(model_dim, num_classes)
logits = linear_layer(cls_token_output)
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)
print(loss)

tensor([[[ -3.6616,   1.2795,   8.1758,   3.0673,  -7.8565,  -5.1824, -13.5376,
           -0.1308],
         [  6.9238,  -7.3512,  -1.2001,  -4.0659,  -2.5937,  -5.9933, -10.0656,
            0.6723],
         [ 11.3471,   7.7629,  14.7286,  -6.5203,  -9.3314,  -0.9771,  -3.2062,
           -5.0264],
         [  6.5818,   1.7215,  -3.0544,   5.1548,  14.1692,   8.5659,   5.4677,
            1.7021]]])
tensor([[[ -3.6616,   1.2795,   8.1758,   3.0673,  -7.8565,  -5.1824, -13.5376,
           -0.1308],
         [  6.9238,  -7.3512,  -1.2001,  -4.0659,  -2.5937,  -5.9933, -10.0656,
            0.6723],
         [ 11.3471,   7.7629,  14.7286,  -6.5203,  -9.3314,  -0.9771,  -3.2062,
           -5.0264],
         [  6.5818,   1.7215,  -3.0544,   5.1548,  14.1692,   8.5659,   5.4677,
            1.7021]]])
tensor(2.8527, grad_fn=<NllLossBackward0>)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值