Implementing a Vision Transformer in TensorFlow 2

Structurally, there are three main pieces to implement:

  1. patch_embedding: the image patch embedding plus a classification (cls) token, plus the pos_embedding
  2. multiHead_Self_Attention: that is, how to obtain q, k and v, and how to multiply them to get the attention
  3. MLP

Stacking these three layers gives one Transformer encoder block, which is then repeated for the desired number of layers.

Finally, take out the classification token and pass it through the MLP_head (really just a Dense layer) to do the classification, and that's it.
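Before diving into the full code, here is a rough, runnable sketch of one such pre-norm encoder block plus the classification head, built only from built-in Keras layers (tf.keras.layers.MultiHeadAttention is just a stand-in here for the MSA layer we will write ourselves below):

import tensorflow as tf

# one encoder block, sketched with built-in layers (stand-in only)
x = tf.random.normal([1, 197, 768])                       # [batch, tokens, embed_dim]
h = x
y = tf.keras.layers.LayerNormalization()(x)
y = tf.keras.layers.MultiHeadAttention(num_heads=12, key_dim=64)(y, y)
x = y + h                                                 # first residual connection
h = x
y = tf.keras.layers.LayerNormalization()(x)
y = tf.keras.layers.Dense(768 * 4, activation='gelu')(y)  # MLP: expand to 4x
y = tf.keras.layers.Dense(768)(y)                         # MLP: project back to embed_dim
x = y + h                                                 # second residual connection
cls = x[:, 0]                                             # the classification token
logits = tf.keras.layers.Dense(5)(cls)                    # MLP head: a single Dense layer
print(logits.shape)                                       # (1, 5)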

Now let's look at the code:

import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
from tensorflow.keras.layers import (Layer,Conv2D,LayerNormalization,
                                    Dense,Input,Dropout,Softmax,Add)

from tensorflow.keras.models import Model
from tensorflow.keras.activations import gelu



# patch_embedding layer: image patch embedding + classification token, plus pos_embedding
class PatchEmbedding(Layer):
    def __init__(self,image_size,patch_size,embed_dim,**kwargs):
        super(PatchEmbedding,self).__init__(**kwargs)

        self.embed_dim = embed_dim
        self.n_patches = (image_size//patch_size) * (image_size//patch_size)
        # each non-overlapping patch_size x patch_size patch is mapped to one embed_dim vector
        self.patch_embed = Conv2D(self.embed_dim,kernel_size=patch_size,strides=patch_size)

        # the classification token; it will be concatenated with the image tokens,
        # giving shape [b,196+1,768]
        self.cls_token = self.add_weight('cls_token',shape=[1,1,self.embed_dim],
                                dtype='float32',initializer='random_normal',
                                trainable=True)
        # pos_embedding is added to (image_tokens + cls_token), so its shape must also be [b,197,768]
        self.pos_embedding = self.add_weight('pos_embedding',shape=[1,self.n_patches+1,self.embed_dim],
                                dtype='float32',initializer='random_normal',
                                trainable=True)
        
    def call(self,inputs):
        # patch_size=16, embed_dim=768
        # [b,224,224,3] -> [b,14,14,768]
        x = self.patch_embed(inputs)
        # [b,14,14,768] -> [b,196,768]
        b,h,w,_ = x.shape
        x = tf.reshape(x,shape=[b,h*w,self.embed_dim])
        # 1,1,768 -> b,1,768
        cls_tokens = tf.broadcast_to(self.cls_token,(b,1,self.embed_dim))
        # prepend the cls token so that x[:,0] is the classification token -> b,197,768
        x = tf.concat([cls_tokens,x],axis=1)

        # add the pos_embedding -> b,197,768
        x = x + self.pos_embedding

        return x

    def get_config(self):
        config = super(PatchEmbedding, self).get_config()
        config.update({"embed_dim": self.embed_dim,
                        "num_patches":self.n_patches,
                        })
        return config
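A quick sanity check on the shapes (a throwaway snippet; the variable names and dummy input are made up for illustration):

# hypothetical shape check for PatchEmbedding
dummy_images = tf.random.normal([1, 224, 224, 3])
patch_embed = PatchEmbedding(image_size=224, patch_size=16, embed_dim=768)
tokens = patch_embed(dummy_images)
print(tokens.shape)   # expected: (1, 197, 768) = 1 cls token + 196 patch tokens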


# implementation of the MSA (multi-head self-attention) layer
class multiHead_self_attention(Layer):
    def __init__(self,embed_dim,num_heads,attention_dropout=0.0,**kwargs):
        super(multiHead_self_attention,self).__init__(**kwargs)

        self.num_heads = num_heads
        self.head_dim = embed_dim // self.num_heads
        self.all_head_dim = self.num_heads * self.head_dim
        
        self.scale = self.head_dim ** (-0.5) # scaling factor applied to q @ k^T

        self.qkv = Dense(self.all_head_dim*3)
        self.proj = Dense(self.all_head_dim)

        self.attention_dropout = Dropout(attention_dropout)

        self.softmax = Softmax()
    
    def call(self,inputs):
        # -> b,197,768*3
        qkv = self.qkv(inputs)
        # q,k,v: b,197,768
        q,k,v = tf.split(qkv,3,axis=-1)
        
        b,n_patches,all_head_dim = q.shape
        # q,k,v: b,197,768 -> b,197,num_heads,head_dim  (assuming num_heads=12)
        # b,197,768 -> b,197,12,64
        q = tf.reshape(q,shape=[b,n_patches,self.num_heads,self.head_dim])
        k = tf.reshape(k,shape=[b,n_patches,self.num_heads,self.head_dim])
        v = tf.reshape(v,shape=[b,n_patches,self.num_heads,self.head_dim])

        # b,197,12,64 -> b,12,197,64
        q = tf.transpose(q,[0,2,1,3])
        k = tf.transpose(k,[0,2,1,3])
        v = tf.transpose(v,[0,2,1,3])
        # -> b,12,197,197
        attention = tf.matmul(q,k,transpose_b=True)
        attention = self.scale * attention
        attention = self.softmax(attention)
        attention = self.attention_dropout(attention)
        # -> b,12,197,64
        out = tf.matmul(attention,v)
        # b,12,197,64 -> b,197,12,64
        out = tf.transpose(out,[0,2,1,3])
        # b,197,12,64 -> b,197,768
        out = tf.reshape(out,shape=[b,n_patches,all_head_dim])

        out = self.proj(out)
        return out

    def get_config(self):
        config = super(multiHead_self_attention, self).get_config()
        config.update({"num_heads": self.num_heads,
                        "head_dim":self.head_dim,
                        "all_head_dim":self.all_head_dim,
                        "scale":self.scale
                        })
        return config
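The same kind of throwaway shape check for the attention layer (names made up). As a cross-check, its parameter count (the qkv Dense: 768*2304 + 2304, plus the proj Dense: 768*768 + 768) comes to 2,362,368, which matches the per-MSA figure in the summary printed at the end:

# hypothetical shape check for the MSA layer
dummy_tokens = tf.random.normal([1, 197, 768])
msa = multiHead_self_attention(embed_dim=768, num_heads=12)
out = msa(dummy_tokens)
print(out.shape)      # expected: (1, 197, 768)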


class MLP(Layer):
    def __init__(self,embed_dim,mlp_ratio=4.0,dropout=0.0,**kwargs):
        super(MLP,self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout

        # create the sub-layers in __init__ (not in call) so that their weights
        # are tracked by this layer and actually get trained
        self.fc1 = Dense(int(self.embed_dim*self.mlp_ratio))
        self.fc2 = Dense(self.embed_dim)
        self.drop = Dropout(self.dropout)

    def call(self,inputs):
        # 1,197,768 -> 1,197,768*4
        x = self.fc1(inputs)
        x = gelu(x)
        x = self.drop(x)

        # 1,197,768*4 -> 1,197,768
        x = self.fc2(x)
        x = self.drop(x)

        return x

    def get_config(self):
        config = super(MLP,self).get_config()
        config.update({
            "embed_dim":self.embed_dim,
            "mlp_ratio":self.mlp_ratio,
            "dropout":self.dropout
        })
        return config

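And one for the MLP block (again just illustrative names); since the sub-layers live in __init__, count_params() reflects the two Dense layers:

# hypothetical shape and parameter check for the MLP block
dummy_tokens = tf.random.normal([1, 197, 768])
mlp = MLP(embed_dim=768)              # mlp_ratio defaults to 4.0, so the hidden width is 3072
out = mlp(dummy_tokens)
print(out.shape)                      # expected: (1, 197, 768)
print(mlp.count_params())             # 768*3072 + 3072 + 3072*768 + 768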

def VisionTransformer(input_shape=[224,224,3],num_classes=5):
    image_size = 224
    num_heads = 12
    patch_size = 16
    embed_dim = 768
    layer_length = 12


    # batch_size is fixed at 1 because the reshape/broadcast calls in the layers
    # above rely on a static batch dimension
    inputs = Input(shape=input_shape,batch_size=1)
    # 1,224,224,3 -> 1,197,768
    x = PatchEmbedding(image_size,patch_size,embed_dim,name='patchAndPos_embedding')(inputs)
    
    # repeat the encoder block layer_length times
    for i in range(1,layer_length+1):
        h = x
        x = LayerNormalization(name=f'LayerNorm{i}_1')(x)
        # 1,197,768 -> 1,197,768
        x = multiHead_self_attention(embed_dim,num_heads,0,name=f'MSA{i}')(x)
        # 1,197,768 -> 1,197,768
        x = Add(name=f'add{i}_1')([x,h])
        h = x

        x = LayerNormalization(name=f'LayerNorm{i}_2')(x)
        # 1,197,768 -> 1,197,768
        x = MLP(embed_dim,name=f'MLP{i}')(x)
        # 1,197,768 -> 1,197,768
        x = Add(name=f'add{i}_2')([x,h])

    # 1,197,768 -> 1,768
    cls_token = x[:,0] # take the first token (the cls token) for classification
    # 1,768 -> 1, num_classes
    x = Dense(num_classes,name='classifier')(cls_token)
    out = Softmax()(x)

    model = Model(inputs=inputs,outputs=out,name='tf2-vit')

    return model


    


if __name__ == '__main__':
    input_shape = [224,224,3]
    num_classes = 5
    vitmodel = VisionTransformer(input_shape,num_classes)
    vitmodel.summary()

Model: "tf2-vit"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to
==================================================================================================
 input_1 (InputLayer)           [(1, 224, 224, 3)]   0           []

 patchAndPos_embedding (PatchEm  (1, 197, 768)       742656      ['input_1[0][0]']
 bedding)

 LayerNorm1_1 (LayerNormalizati  (1, 197, 768)       1536        ['patchAndPos_embedding[0][0]']
 on)

 MSA1 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm1_1[0][0]']
 )

 add1_1 (Add)                   (1, 197, 768)        0           ['MSA1[0][0]',
                                                                  'patchAndPos_embedding[0][0]']

 LayerNorm1_2 (LayerNormalizati  (1, 197, 768)       1536        ['add1_1[0][0]']
 on)

 MLP1 (MLP)                     (1, 197, 768)        0           ['LayerNorm1_2[0][0]']

 add1_2 (Add)                   (1, 197, 768)        0           ['MLP1[0][0]',
                                                                  'add1_1[0][0]']

 LayerNorm2_1 (LayerNormalizati  (1, 197, 768)       1536        ['add1_2[0][0]']
 on)

 MSA2 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm2_1[0][0]']
 )

 add2_1 (Add)                   (1, 197, 768)        0           ['MSA2[0][0]',
                                                                  'add1_2[0][0]']

 LayerNorm2_2 (LayerNormalizati  (1, 197, 768)       1536        ['add2_1[0][0]']
 on)

 MLP2 (MLP)                     (1, 197, 768)        0           ['LayerNorm2_2[0][0]']

 add2_2 (Add)                   (1, 197, 768)        0           ['MLP2[0][0]',
                                                                  'add2_1[0][0]']

 LayerNorm3_1 (LayerNormalizati  (1, 197, 768)       1536        ['add2_2[0][0]']
 on)

 MSA3 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm3_1[0][0]']
 )

 add3_1 (Add)                   (1, 197, 768)        0           ['MSA3[0][0]',
                                                                  'add2_2[0][0]']

 LayerNorm3_2 (LayerNormalizati  (1, 197, 768)       1536        ['add3_1[0][0]']
 on)

 MLP3 (MLP)                     (1, 197, 768)        0           ['LayerNorm3_2[0][0]']

 add3_2 (Add)                   (1, 197, 768)        0           ['MLP3[0][0]',
                                                                  'add3_1[0][0]']

 LayerNorm4_1 (LayerNormalizati  (1, 197, 768)       1536        ['add3_2[0][0]']
 on)

 MSA4 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm4_1[0][0]']
 )

 add4_1 (Add)                   (1, 197, 768)        0           ['MSA4[0][0]',
                                                                  'add3_2[0][0]']

 LayerNorm4_2 (LayerNormalizati  (1, 197, 768)       1536        ['add4_1[0][0]']
 on)

 MLP4 (MLP)                     (1, 197, 768)        0           ['LayerNorm4_2[0][0]']

 add4_2 (Add)                   (1, 197, 768)        0           ['MLP4[0][0]',
                                                                  'add4_1[0][0]']

 LayerNorm5_1 (LayerNormalizati  (1, 197, 768)       1536        ['add4_2[0][0]']
 on)

 MSA5 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm5_1[0][0]']
 )

 add5_1 (Add)                   (1, 197, 768)        0           ['MSA5[0][0]',
                                                                  'add4_2[0][0]']

 LayerNorm5_2 (LayerNormalizati  (1, 197, 768)       1536        ['add5_1[0][0]']
 on)

 MLP5 (MLP)                     (1, 197, 768)        0           ['LayerNorm5_2[0][0]']

 add5_2 (Add)                   (1, 197, 768)        0           ['MLP5[0][0]',
                                                                  'add5_1[0][0]']

 LayerNorm6_1 (LayerNormalizati  (1, 197, 768)       1536        ['add5_2[0][0]']
 on)

 MSA6 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm6_1[0][0]']
 )

 add6_1 (Add)                   (1, 197, 768)        0           ['MSA6[0][0]',
                                                                  'add5_2[0][0]']

 LayerNorm6_2 (LayerNormalizati  (1, 197, 768)       1536        ['add6_1[0][0]']
 on)

 MLP6 (MLP)                     (1, 197, 768)        0           ['LayerNorm6_2[0][0]']

 add6_2 (Add)                   (1, 197, 768)        0           ['MLP6[0][0]',
                                                                  'add6_1[0][0]']

 LayerNorm7_1 (LayerNormalizati  (1, 197, 768)       1536        ['add6_2[0][0]']
 on)

 MSA7 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm7_1[0][0]']
 )

 add7_1 (Add)                   (1, 197, 768)        0           ['MSA7[0][0]',
                                                                  'add6_2[0][0]']

 LayerNorm7_2 (LayerNormalizati  (1, 197, 768)       1536        ['add7_1[0][0]']
 on)

 MLP7 (MLP)                     (1, 197, 768)        0           ['LayerNorm7_2[0][0]']

 add7_2 (Add)                   (1, 197, 768)        0           ['MLP7[0][0]',
                                                                  'add7_1[0][0]']

 LayerNorm8_1 (LayerNormalizati  (1, 197, 768)       1536        ['add7_2[0][0]']
 on)

 MSA8 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm8_1[0][0]']
 )

 add8_1 (Add)                   (1, 197, 768)        0           ['MSA8[0][0]',
                                                                  'add7_2[0][0]']

 LayerNorm8_2 (LayerNormalizati  (1, 197, 768)       1536        ['add8_1[0][0]']
 on)

 MLP8 (MLP)                     (1, 197, 768)        0           ['LayerNorm8_2[0][0]']

 add8_2 (Add)                   (1, 197, 768)        0           ['MLP8[0][0]',
                                                                  'add8_1[0][0]']

 LayerNorm9_1 (LayerNormalizati  (1, 197, 768)       1536        ['add8_2[0][0]']
 on)

 MSA9 (multiHead_self_attention  (1, 197, 768)       2362368     ['LayerNorm9_1[0][0]']
 )

 add9_1 (Add)                   (1, 197, 768)        0           ['MSA9[0][0]',
                                                                  'add8_2[0][0]']

 LayerNorm9_2 (LayerNormalizati  (1, 197, 768)       1536        ['add9_1[0][0]']
 on)

 MLP9 (MLP)                     (1, 197, 768)        0           ['LayerNorm9_2[0][0]']

 add9_2 (Add)                   (1, 197, 768)        0           ['MLP9[0][0]',
                                                                  'add9_1[0][0]']

 LayerNorm10_1 (LayerNormalizat  (1, 197, 768)       1536        ['add9_2[0][0]']
 ion)

 MSA10 (multiHead_self_attentio  (1, 197, 768)       2362368     ['LayerNorm10_1[0][0]']
 n)

 add10_1 (Add)                  (1, 197, 768)        0           ['MSA10[0][0]',
                                                                  'add9_2[0][0]']

 LayerNorm10_2 (LayerNormalizat  (1, 197, 768)       1536        ['add10_1[0][0]']
 ion)

 MLP10 (MLP)                    (1, 197, 768)        0           ['LayerNorm10_2[0][0]']

 add10_2 (Add)                  (1, 197, 768)        0           ['MLP10[0][0]',
                                                                  'add10_1[0][0]']

 LayerNorm11_1 (LayerNormalizat  (1, 197, 768)       1536        ['add10_2[0][0]']
 ion)

 MSA11 (multiHead_self_attentio  (1, 197, 768)       2362368     ['LayerNorm11_1[0][0]']
 n)

 add11_1 (Add)                  (1, 197, 768)        0           ['MSA11[0][0]',
                                                                  'add10_2[0][0]']

 LayerNorm11_2 (LayerNormalizat  (1, 197, 768)       1536        ['add11_1[0][0]']
 ion)

 MLP11 (MLP)                    (1, 197, 768)        0           ['LayerNorm11_2[0][0]']

 add11_2 (Add)                  (1, 197, 768)        0           ['MLP11[0][0]',
                                                                  'add11_1[0][0]']

 LayerNorm12_1 (LayerNormalizat  (1, 197, 768)       1536        ['add11_2[0][0]']
 ion)

 MSA12 (multiHead_self_attentio  (1, 197, 768)       2362368     ['LayerNorm12_1[0][0]']
 n)

 add12_1 (Add)                  (1, 197, 768)        0           ['MSA12[0][0]',
                                                                  'add11_2[0][0]']

 LayerNorm12_2 (LayerNormalizat  (1, 197, 768)       1536        ['add12_1[0][0]']
 ion)

 MLP12 (MLP)                    (1, 197, 768)        0           ['LayerNorm12_2[0][0]']

 add12_2 (Add)                  (1, 197, 768)        0           ['MLP12[0][0]',
                                                                  'add12_1[0][0]']

 tf.__operators__.getitem (Slic  (1, 768)            0           ['add12_2[0][0]']
 ingOpLambda)

 classifier (Dense)             (1, 5)               3845        ['tf.__operators__.getitem[0][0]'
                                                                 ]

 softmax_12 (Softmax)           (1, 5)               0           ['classifier[0][0]']

==================================================================================================
Total params: 29,131,781
Trainable params: 29,131,781
Non-trainable params: 0
_________________________________________________________________________________________
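To actually train the model, something along these lines should work (a minimal sketch with random data and made-up hyperparameters; real dataset loading, augmentation, learning-rate schedules and so on are omitted):

# minimal training sketch (dummy data, hypothetical hyperparameters)
import numpy as np

model = VisionTransformer(input_shape=[224,224,3], num_classes=5)
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# random images/labels just to show the expected shapes; batch_size must stay 1
# because the Input layer fixes batch_size=1
x_train = np.random.rand(8, 224, 224, 3).astype('float32')
y_train = np.random.randint(0, 5, size=(8,))
model.fit(x_train, y_train, batch_size=1, epochs=1)

# inference: the final Softmax gives class probabilities
probs = model.predict(np.random.rand(1, 224, 224, 3).astype('float32'), batch_size=1)
print(probs.shape)    # (1, 5)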
