tensorflow2实现vision transformer

import tensorflow as tf
from tensorflow.keras.layers import (Dense,Conv2D,LayerNormalization,
                                Layer,Dropout,Input,GlobalAveragePooling1D,)
from tensorflow.keras import Sequential,Model


class Identity(Layer):
    """No-op layer: returns its input tensor unchanged.

    Serves as a placeholder where a real sub-layer (e.g. self-attention)
    will be plugged in later.

    Usage:
        attn = Identity()          # instantiate first
        out = attn(a_tensor)       # then call on a tensor; out is a_tensor
    """

    def __init__(self):
        super().__init__()

    def call(self, inputs):
        # Pass the tensor straight through.
        return inputs


class PatchEmbedding(Layer):
    """Split an image into non-overlapping patches and linearly embed them.

    A Conv2D whose kernel size equals its stride (both = patch_size) is
    equivalent to cutting the image into patches and applying a shared
    linear projection to each patch.

    Args:
        patch_size: side length of each square patch (also the conv stride).
        embed_dim: embedding dimension produced for each patch.
        dropout: dropout rate applied to the patch embeddings.
    """
    def __init__(self, patch_size, embed_dim, dropout=0.):
        super().__init__()
        # Conv2D(filters, kernel_size, strides): kernel == stride == patch_size.
        self.patch_embed = Conv2D(embed_dim, patch_size, patch_size)
        self.dropout = Dropout(dropout)

    def call(self, inputs):
        # e.g. [batch, 224, 224, 3] -> [batch, 32, 32, 16] for patch_size=7
        x = self.patch_embed(inputs)

        # Flatten spatial dims: [batch, h', w', embed_dim] -> [batch, h'*w', embed_dim].
        # BUGFIX: use -1 for the batch axis. x.shape[0] is None whenever the
        # batch size is dynamic (the usual case), and tf.reshape rejects a
        # None dimension; the original only worked with a fixed batch_size.
        x = tf.reshape(x, shape=[-1, x.shape[1] * x.shape[2], x.shape[3]])

        x = self.dropout(x)

        return x

class MLP(Layer):
    """Transformer feed-forward block: expand, GELU, project back down.

    Args:
        embed_dim: input/output feature dimension.
        mlp_ratio: hidden width as a multiple of embed_dim.
        dropout: dropout rate applied after each Dense layer.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden_units = int(embed_dim * mlp_ratio)
        self.fc1 = Dense(hidden_units)
        self.fc2 = Dense(embed_dim)
        self.dropout = Dropout(rate=dropout)

    def call(self, inputs):
        # Expand: [..., embed_dim] -> [..., embed_dim * mlp_ratio], GELU, dropout.
        hidden = self.dropout(tf.nn.gelu(self.fc1(inputs)))

        # Project back: [..., embed_dim * mlp_ratio] -> [..., embed_dim], dropout.
        return self.dropout(self.fc2(hidden))

class Encoder(Layer):
    """Pre-norm transformer encoder block with residual connections.

    Each sub-layer follows the pre-norm pattern: x = x + SubLayer(LN(x)).
    The attention sub-layer is still a placeholder (Identity) — see TODO.
    """

    def __init__(self, embed_dims):
        super().__init__()
        self.atten = Identity()  # TODO: replace with multi-head self-attention
        self.atten_norm = LayerNormalization()
        self.mlp = MLP(embed_dims)
        self.mlp_norm = LayerNormalization()

    def call(self, inputs):
        # Attention sub-layer: [batch, h'*w', embed_dims] shape is preserved.
        residual = inputs
        out = self.atten(self.atten_norm(inputs))  # normalize first (pre-norm)
        out = out + residual

        # MLP sub-layer, same residual pattern.
        residual = out
        out = self.mlp(self.mlp_norm(out))
        return out + residual

class ViT(Layer):
    """Minimal Vision Transformer: patch embedding -> encoder stack -> head.

    NOTE(review): this skeleton omits the class token and positional
    embeddings of the original ViT; classification pools the patch
    sequence with global average pooling instead.

    Args:
        patch_size: patch side length passed to PatchEmbedding.
        embed_dims: embedding dimension of each patch token.
        encoder_length: number of stacked Encoder blocks.
        num_classes: dimension of the output logits.
    """
    def __init__(self, patch_size, embed_dims, encoder_length=5, num_classes=2):
        super().__init__()
        self.patch_embed = PatchEmbedding(patch_size=patch_size, embed_dim=embed_dims)
        # Stack encoder_length identical encoder blocks.
        # (Removed a dead `layer_list = []` assignment that was immediately
        # overwritten by the comprehension.)
        self.encoders = Sequential(
            [Encoder(embed_dims=embed_dims) for _ in range(encoder_length)]
        )
        self.head = Dense(num_classes)
        self.avgpool = GlobalAveragePooling1D()
        self.layernorm = LayerNormalization()

    def call(self, inputs):
        # [batch, h, w, c] -> [batch, h'*w', embed_dims]
        x = self.patch_embed(inputs)

        # Run the token sequence through the encoder stack.
        x = self.encoders(x)

        # Final LayerNorm over the embedding dimension.
        x = self.layernorm(x)

        # Pool over tokens: [batch, h'*w', embed_dims] -> [batch, embed_dims]
        x = self.avgpool(x)

        # Classification logits: [batch, embed_dims] -> [batch, num_classes]
        return self.head(x)



if __name__ == '__main__':
    # Smoke test: build the model on a fixed 224x224 RGB input, batch size 4.
    image_input = Input(shape=(224, 224, 3), batch_size=4)
    vit = ViT(patch_size=7, embed_dims=16, encoder_length=5, num_classes=2)
    logits = vit(image_input)
    model = Model(inputs=image_input, outputs=logits, name='vit-tf2')
    model.summary()

写了一下框架,还没有实现attention,标记为TODO

Model: "vit-tf2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(4, 224, 224, 3)]        0
_________________________________________________________________
vi_t (ViT)                   (4, 2)                    13394
=================================================================
Total params: 13,394
Trainable params: 13,394
Non-trainable params: 0
_________________________________________________________________

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值