Writing a Vision Transformer (ViT) from Scratch

# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

import torch
import torch.nn as nn
import ml_collections

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair

logger = logging.getLogger(__name__)

def swish(x):
    return x*torch.sigmoid(x)

ACT2FN={"gelu":torch.nn.functional.gelu,"relu":torch.nn.functional.relu,"swish":swish}

def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config
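
As a quick sanity check on how these numbers fit together for the 224×224 inputs used later in this post (a minimal sketch, not part of the model code):

config = get_b16_config()
num_patches = (224 // config.patches['size'][0]) * (224 // config.patches['size'][1])  # 14 * 14 = 196
seq_len = num_patches + 1                                                               # 196 patches + 1 [CLS] token = 197
head_dim = config.hidden_size // config.transformer.num_heads                           # 768 / 12 = 64 per head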

class Embeddings(nn.Module):
    '''
        Build the input embeddings from image patches, a [CLS] token, and position embeddings.
    '''
    def __init__(self,config,img_size,in_channels=3):
        super(Embeddings,self).__init__()
        # original image size
        img_size=_pair(img_size)
        # patch size
        patch_size=_pair(config.patches['size'])
        # number of patches
        n_patches=(img_size[0]//patch_size[0])*(img_size[1]//patch_size[1])

        # patch embedding via a strided convolution
        # input is [B,3,224,224], output is [B,768,14,14]
        self.patch_embeddings=Conv2d(in_channels=in_channels,
                                     out_channels=config.hidden_size,
                                     kernel_size=patch_size,
                                     stride=patch_size)
        # learnable position embeddings, shape [1,197,768]
        self.position_embeddings=nn.Parameter(torch.zeros(1,n_patches+1,config.hidden_size))

        # classification token
        self.cls_token=nn.Parameter(torch.zeros(1,1,config.hidden_size))

        self.dropout=Dropout(config.transformer['dropout_rate'])
    def forward(self,x):
        # [B,3,224,224]
        # batch size
        B=x.shape[0]

        # broadcast cls_token so every sample in the batch gets its own copy
        # expand(B,-1,-1) repeats along dim 0 and keeps dims 1 and 2 unchanged
        cls_tokens=self.cls_token.expand(B,-1,-1)

        # split the input image into patches
        # [B,768,14,14]
        x=self.patch_embeddings(x)

        # flatten the spatial dimensions: flatten(2) merges H and W
        # [B,768,196]
        x=x.flatten(2)

        #[B,196,768]
        x=x.transpose(-1,-2)
        # concatenate the cls token along the token dimension
        # [B,197,768]
        x=torch.cat((cls_tokens,x),dim=1)

        # add the position embeddings
        embeddings= x+self.position_embeddings
        embeddings=self.dropout(embeddings)
        return embeddings
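
A minimal shape check for Embeddings on random data (a sketch under the ViT-B/16 config above; no trained weights involved):

emb = Embeddings(get_b16_config(), img_size=224, in_channels=3)
out = emb(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 197, 768])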

class Attention(nn.Module):
    def __init__(self,config,vis):
        super(Attention,self).__init__()
        self.vis=vis
        # number of attention heads
        self.num_attention_heads=config.transformer['num_heads']
        # per-head dimension: the hidden size is split evenly across the heads
        self.attention_head_size=int(config.hidden_size/self.num_attention_heads)

        self.all_head_size=self.num_attention_heads*self.attention_head_size

        self.query=Linear(config.hidden_size,self.all_head_size)
        self.key=Linear(config.hidden_size,self.all_head_size)
        self.value=Linear(config.hidden_size,self.all_head_size)

        self.out = Linear(config.hidden_size,config.hidden_size)
        self.attn_dropout= Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout=Dropout(config.transformer["attention_dropout_rate"])
        # softmax over the last dimension
        self.softmax=Softmax(dim=-1)

    def transpose_for_scores(self,x):
        # [B,197,768] -> [B,197,12,64] -> [B,12,197,64]
        new_x_shape=x.size()[:-1] +(self.num_attention_heads,self.attention_head_size)
        x=x.view(*new_x_shape)

        return x.permute(0,2,1,3)
    def forward(self,hidden_states):
        # [B, 197, 768]
        # compute Q, K, V
        mixed_query_layer=self.query(hidden_states)
        mixed_key_layer=self.key(hidden_states)
        mixed_value_layer=self.value(hidden_states)

        # [B,12,197,64]
        # split into multiple attention heads
        query_layer=self.transpose_for_scores(mixed_query_layer)
        key_layer=self.transpose_for_scores(mixed_key_layer)
        value_layer=self.transpose_for_scores(mixed_value_layer)

        # Q·K^T gives the attention weights over the values
        # [B,12,197,197]
        attention_scores=torch.matmul(query_layer,key_layer.transpose(-1,-2))
        attention_scores=attention_scores / math.sqrt(self.attention_head_size)
        attention_probs=self.softmax(attention_scores)
        weights=attention_probs if self.vis else None

        # weighted sum of the values, [B,12,197,64]
        context_layer=torch.matmul(attention_probs,value_layer)
        # [B,197,12,64]
        context_layer=context_layer.permute(0,2,1,3).contiguous()
        # [B,197,768]
        new_context_layer_shape=context_layer.size()[:-2]+(self.all_head_size,)
        context_layer=context_layer.view(*new_context_layer_shape)

        attention_output=self.out(context_layer)
        attention_output=self.proj_dropout(attention_output)
        return attention_output,weights
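
A minimal shape check for Attention on its own (a sketch; vis=True so the attention map is returned instead of None):

attn = Attention(get_b16_config(), vis=True)
out, weights = attn(torch.randn(2, 197, 768))
print(out.shape)      # torch.Size([2, 197, 768])
print(weights.shape)  # torch.Size([2, 12, 197, 197])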

class MLP(nn.Module):
    def __init__(self,config):
        super(MLP,self).__init__()
        self.fc1=Linear(config.hidden_size,config.transformer["mlp_dim"])
        self.fc2=Linear(config.transformer["mlp_dim"],config.hidden_size)
        self.act_fn=ACT2FN['gelu']
        self.dropout=Dropout(config.transformer["dropout_rate"])
        # initialize the weights
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias,std=1e-6)
        nn.init.normal_(self.fc2.bias,std=1e-6)
    def forward(self,x):
        # input: [B,197,768]
        # after fc1: [B,197,3072]
        x=self.fc1(x)
        x=self.act_fn(x)
        x=self.dropout(x)
        # after fc2: [B,197,768]
        x=self.fc2(x)
        x=self.dropout(x)
        return x


class Block(nn.Module):
    def __init__(self,config,vis):
        super(Block,self).__init__()
        self.hidden_size=config.hidden_size
        self.attention_norm=LayerNorm(config.hidden_size,eps=1e-6)
        self.ffn_norm=LayerNorm(config.hidden_size,eps=1e-6)
        self.ffn=MLP(config)
        self.attn=Attention(config,vis)

    def forward(self,x):
        # pre-norm block: LayerNorm -> self-attention -> residual add, then LayerNorm -> MLP -> residual add
        h=x
        x=self.attention_norm(x)
        x,weight=self.attn(x)
        x=x+h

        h=x
        x=self.ffn_norm(x)
        x=self.ffn(x)
        x=x+h
        return x,weight

class Encoder(nn.Module):
    def __init__(self,config,vis):
        super(Encoder,self).__init__()
        self.vis=vis
        self.layer=nn.ModuleList()
        self.encoder_norm=LayerNorm(config.hidden_size,eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer=Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self,hidden_states):
        attn_weights=[]
        for layer_block in self.layer:
            hidden_states,weights=layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded=self.encoder_norm(hidden_states)
        return encoded,attn_weights
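
With vis=True, Encoder collects one [B,12,197,197] attention map per Block, so attn_weights comes back as a list with num_layers entries (a minimal sketch on random tokens):

enc = Encoder(get_b16_config(), vis=True)
encoded, attn_weights = enc(torch.randn(2, 197, 768))
print(encoded.shape)      # torch.Size([2, 197, 768])
print(len(attn_weights))  # 12, one attention map per layer
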
class Transformer(nn.Module):
    def __init__(self,config,img_size,vis):
        super(Transformer,self).__init__()
        self.embeddings=Embeddings(config,img_size=img_size)
        self.encoder=Encoder(config, vis)

    def forward(self,input_ids):
        embeddings_output=self.embeddings(input_ids)
        encoded,attn_weights=self.encoder(embeddings_output)
        return encoded,attn_weights

class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier

        self.transformer = Transformer(config, img_size, vis)
        self.head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None):
        x, attn_weights = self.transformer(x)
        # print(x.shape)
        # classify from the [CLS] token only: [B,197,768] -> [B,768] -> [B,num_classes]
        logits = self.head(x[:, 0])
        # print(logits.shape)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            return loss
        else:
            return logits, attn_weights

if __name__ == '__main__':
    config = get_b16_config()
    # embeddings=Embeddings(config,img_size=224,in_channels=3)
    # x=torch.randn(64,3,224,224)
    # out=embeddings(x)
    # print(out.shape)
    x=torch.randn(64,3,224,224)
    label=torch.zeros(64,dtype=torch.long)  # labels must be long integers; pass ViT(x,label) to get a loss instead of logits
    img_size=224
    ViT=VisionTransformer(config,img_size)
    print(ViT)
    logits,attn_weights=ViT(x)
    print(logits.shape)
    print(logits[0])
    print(torch.argmax(logits[0]))