import tensorflow as tf
from tensorflow.keras.layers import (Dense, Conv2D, LayerNormalization,
                                     Layer, Dropout, Input, GlobalAveragePooling1D)
from tensorflow.keras import Sequential, Model
class Identity(Layer):
    # usage:
    #   first instantiate:      attn = Identity()
    #   then pass in a tensor:  out = attn(a_tensor)
    def __init__(self):
        super().__init__()

    def call(self, inputs):
        return inputs
class PatchEmbedding(Layer):
    # e.g. image_size=[224,224], in_channels=3, patch_size=7, embed_dim=16
    def __init__(self, patch_size, embed_dim, dropout=0.):
        super().__init__()
        # a Conv2D with kernel_size == strides == patch_size cuts the image
        # into non-overlapping patches and projects each one to embed_dim
        self.patch_embed = Conv2D(embed_dim, kernel_size=patch_size, strides=patch_size)
        self.dropout = Dropout(dropout)

    def call(self, inputs):
        # [batch, 224, 224, 3] -> [batch, 32, 32, 16]
        x = self.patch_embed(inputs)
        # [batch, 32, 32, 16] -> [batch, 32*32, 16]
        # use -1 for the batch dimension so the reshape also works when the
        # batch size is dynamic (x.shape[0] can be None)
        x = tf.reshape(x, shape=[-1, x.shape[1] * x.shape[2], x.shape[3]])
        x = self.dropout(x)
        return x
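With patch_size=7 and embed_dim=16, a 224x224 image gives 224/7 = 32 patches per side, i.e. 32*32 = 1024 tokens. A quick shape check (my own snippet, not part of the original post):

# sanity check for the shapes in the comments above
pe = PatchEmbedding(patch_size=7, embed_dim=16)
dummy = tf.zeros([4, 224, 224, 3])
print(pe(dummy).shape)  # (4, 1024, 16)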
class MLP(Layer):
    def __init__(self, embed_dim, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.fc1 = Dense(int(embed_dim * mlp_ratio))
        self.fc2 = Dense(embed_dim)
        self.dropout = Dropout(rate=dropout)

    def call(self, inputs):
        # [batch, num_patches, embed_dim] -> [batch, num_patches, embed_dim*mlp_ratio]
        x = self.fc1(inputs)
        x = tf.nn.gelu(x)  # GELU activation
        x = self.dropout(x)
        # [batch, num_patches, embed_dim*mlp_ratio] -> [batch, num_patches, embed_dim]
        x = self.fc2(x)
        x = self.dropout(x)
        return x
class Encoder(Layer):
    def __init__(self, embed_dims):
        super().__init__()
        self.atten = Identity()  # TODO: replace with multi-head self-attention
        self.atten_norm = LayerNormalization()
        self.mlp = MLP(embed_dims)
        self.mlp_norm = LayerNormalization()

    def call(self, inputs):
        # [batch, h'*w', embed_dims] -> [batch, h'*w', embed_dims]
        h = inputs
        x = self.atten_norm(inputs)  # pre-norm: LayerNormalization first
        x = self.atten(x)
        x = x + h  # residual connection around attention
        h = x
        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = x + h  # residual connection around the MLP
        return x
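Note that the block uses the pre-norm layout from the ViT paper: each sublayer normalizes its input first and then adds the residual shortcut. A quick check (again my own snippet) that a block preserves the token shape:

enc = Encoder(embed_dims=16)
tokens = tf.zeros([4, 1024, 16])
print(enc(tokens).shape)  # (4, 1024, 16)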
class ViT(Layer):
    def __init__(self, patch_size, embed_dims, encoder_length=5, num_classes=2):
        super().__init__()
        self.patch_embed = PatchEmbedding(patch_size=patch_size, embed_dim=embed_dims)
        # stack of encoder blocks
        layer_list = [Encoder(embed_dims=embed_dims) for _ in range(encoder_length)]
        self.encoders = Sequential(layer_list)
        self.head = Dense(num_classes)
        self.avgpool = GlobalAveragePooling1D()
        self.layernorm = LayerNormalization()

    def call(self, inputs):
        # [batch, h, w, in_channels] -> [batch, h'*w', embed_dims]
        x = self.patch_embed(inputs)
        # run through encoder_length encoder blocks
        x = self.encoders(x)
        # layernorm over the embed_dims dimension
        x = self.layernorm(x)
        # [batch, h'*w', embed_dims] -> [batch, embed_dims]
        x = self.avgpool(x)
        # [batch, embed_dims] -> [batch, num_classes]
        x = self.head(x)
        return x
if __name__ == '__main__':
    inputs = Input(shape=(224, 224, 3), batch_size=4)
    vision_transformer = ViT(patch_size=7, embed_dims=16, encoder_length=5, num_classes=2)
    out = vision_transformer(inputs)
    model = Model(inputs=inputs, outputs=out, name='vit-tf2')
    model.summary()
This is just the skeleton so far; attention is not implemented yet and is marked as TODO (a possible sketch follows the summary below).
Model: "vit-tf2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(4, 224, 224, 3)]        0
_________________________________________________________________
vi_t (ViT)                   (4, 2)                    13394
=================================================================
Total params: 13,394
Trainable params: 13,394
Non-trainable params: 0
_________________________________________________________________
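For reference, here is a minimal sketch of a layer that could fill the TODO, assuming the standard multi-head self-attention from the ViT paper; the class name MultiHeadSelfAttention and num_heads=4 are my own choices, not part of the original skeleton:

class MultiHeadSelfAttention(Layer):
    # sketch only: standard scaled dot-product multi-head self-attention
    def __init__(self, embed_dim, num_heads=4, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0, 'embed_dim must divide evenly across heads'
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = Dense(embed_dim * 3)  # project to q, k, v in one pass
        self.proj = Dense(embed_dim)     # output projection
        self.dropout = Dropout(dropout)

    def call(self, inputs):
        # inputs: [batch, num_patches, embed_dim]
        batch = tf.shape(inputs)[0]
        n = tf.shape(inputs)[1]
        qkv = self.qkv(inputs)                                 # [b, n, 3*embed_dim]
        qkv = tf.reshape(qkv, [batch, n, 3, self.num_heads, self.head_dim])
        qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])               # [3, b, heads, n, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = tf.matmul(q, k, transpose_b=True) * self.scale  # [b, heads, n, n]
        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.dropout(attn)
        out = tf.matmul(attn, v)                               # [b, heads, n, head_dim]
        out = tf.transpose(out, [0, 2, 1, 3])                  # [b, n, heads, head_dim]
        out = tf.reshape(out, [batch, n, self.num_heads * self.head_dim])
        return self.proj(out)

With this in place, self.atten = Identity() in Encoder would become self.atten = MultiHeadSelfAttention(embed_dims), keeping the rest of the block unchanged (embed_dims must be divisible by num_heads).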