1. Inception
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, BatchNormalization, Activation, Dropout, GlobalAveragePooling2D
from tensorflow.keras import Model
class ConvBNRelu(Model):
    """Conv2D -> BatchNormalization -> ReLU, bundled as a single Keras model.

    Args:
        ch: Number of output filters of the convolution.
        kernel_size: Kernel size forwarded to ``Conv2D``. Defaults to 3.
        strides: Stride forwarded to ``Conv2D``. Defaults to 1.
        padding: Padding mode forwarded to ``Conv2D``. Defaults to ``'same'``.
    """

    def __init__(self, ch, kernel_size=3, strides=1, padding='same'):
        super(ConvBNRelu, self).__init__()
        self.model = tf.keras.models.Sequential([
            Conv2D(ch, kernel_size, strides=strides, padding=padding),
            BatchNormalization(),
            Activation('relu'),
        ])

    def call(self, x):
        """Run ``x`` through the conv/BN/ReLU stack and return the result."""
        # The result must be returned explicitly; without the return every
        # caller of this layer would receive None instead of a tensor.
        return self.model(x)
class InceptionBlk(Model):
    """Inception block: four parallel branches concatenated along axis 3.

    Branches (each convolution is a ``ConvBNRelu`` with ``ch`` filters):
        1. 1x1 conv
        2. 1x1 conv -> 3x3 conv
        3. 1x1 conv -> 5x5 conv
        4. 3x3 max-pool (stride 1) -> 1x1 conv

    Args:
        ch: Number of filters used by every convolution in the block.
        strides: Stride applied once per branch, so that all four branch
            outputs have the same spatial size. Defaults to 1.
    """

    def __init__(self, ch, strides=1):
        super(InceptionBlk, self).__init__()
        self.ch = ch
        self.strides = strides
        self.c1 = ConvBNRelu(ch, kernel_size=1, strides=strides)
        self.c2_1 = ConvBNRelu(ch, kernel_size=1, strides=strides)
        self.c2_2 = ConvBNRelu(ch, kernel_size=3, strides=1)
        self.c3_1 = ConvBNRelu(ch, kernel_size=1, strides=strides)
        self.c3_2 = ConvBNRelu(ch, kernel_size=5, strides=1)
        self.p4_1 = MaxPool2D(3, strides=1, padding='same')
        # The pooling layer keeps the input size (stride 1), so the 1x1 conv
        # that follows has to carry the block stride. With a fixed stride of 1
        # this branch would stay full-size while the others shrink, and the
        # concatenation below would fail for strides != 1.
        self.c4_2 = ConvBNRelu(ch, kernel_size=1, strides=strides)

    def call(self, x):
        """Apply the four branches to ``x`` and concatenate their outputs."""
        x1 = self.c1(x)
        x2_1 = self.c2_1(x)
        x2_2 = self.c2_2(x2_1)
        x3_1 = self.c3_1(x)
        x3_2 = self.c3_2(x3_1)
        x4_1 = self.p4_1(x)
        x4_2 = self.c4_2(x4_1)
        # Stack the branch outputs along the depth direction (axis 3).
        x = tf.concat([x1, x2_2, x3_2, x4_2], axis=3)
        return x
class Inception10(Model):
    """Small Inception-style classifier.

    Layout: one ``ConvBNRelu`` stem, then ``num_blocks`` groups of two
    ``InceptionBlk`` layers (the first with stride 2, the second with
    stride 1), global average pooling, and a softmax ``Dense`` output.

    Args:
        num_blocks: Number of two-layer Inception groups.
        num_classes: Number of units of the softmax output layer.
        init_ch: Filter count of the stem and of the first group; the count
            passed to the Inception blocks doubles after every group.
        **kwargs: Forwarded to ``tf.keras.Model.__init__``.
    """

    def __init__(self, num_blocks, num_classes, init_ch=16, **kwargs):
        super(Inception10, self).__init__(**kwargs)
        self.in_channels = init_ch
        self.out_channels = init_ch
        self.num_blocks = num_blocks
        self.init_ch = init_ch
        # Pass the kernel size explicitly: ConvBNRelu requires it, and a 3x3
        # stem convolution is the conventional choice for this network.
        self.c1 = ConvBNRelu(init_ch, kernel_size=3)
        self.blocks = tf.keras.models.Sequential()
        for _ in range(num_blocks):
            for layer_id in range(2):
                # First layer of each group downsamples; the second keeps size.
                stride = 2 if layer_id == 0 else 1
                self.blocks.add(InceptionBlk(self.out_channels, strides=stride))
            # Widen the blocks of the next group.
            self.out_channels *= 2
        self.p1 = GlobalAveragePooling2D()
        self.f1 = Dense(num_classes, activation='softmax')

    def call(self, x):
        """Return the softmax class scores for the input batch ``x``."""
        x = self.c1(x)
        x = self.blocks(x)
        x = self.p1(x)
        y = self.f1(x)
        return y
2. ResNet
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, BatchNormalization, Activation, Dropout, GlobalAveragePooling2D
from tensorflow.keras import Model
class ResnetBlock(Model):
    """Residual block: two 3x3 conv/BN layers plus a shortcut connection.

    Args:
        filters: Number of filters of every convolution in the block.
        strides: Stride of the first 3x3 convolution (and of the 1x1
            shortcut convolution when ``residual_path`` is true).
        residual_path: When true, the shortcut goes through a 1x1
            convolution and batch normalization so that it matches the
            shape of the main path; when false, the input is added as is.
    """

    def __init__(self, filters, strides=1, residual_path=False):
        super().__init__()
        self.filters = filters
        self.strides = strides
        self.residual_path = residual_path

        # Main path: conv -> BN -> ReLU -> conv -> BN.
        self.c1 = Conv2D(filters, (3, 3), strides=strides,
                         padding='same', use_bias=False)
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')
        self.c2 = Conv2D(filters, (3, 3), strides=1,
                         padding='same', use_bias=False)
        self.b2 = BatchNormalization()

        # Shortcut projection, needed only when the dimensions differ.
        if self.residual_path:
            self.down_c1 = Conv2D(filters, (1, 1), strides=strides,
                                  padding='same', use_bias=False)
            self.down_b1 = BatchNormalization()

        # Activation applied after the main path and shortcut are summed.
        self.a2 = Activation('relu')

    def call(self, inputs):
        """Return ``relu(main_path(inputs) + shortcut(inputs))``."""
        main = self.c1(inputs)
        main = self.b1(main)
        main = self.a1(main)
        main = self.c2(main)
        main = self.b2(main)

        if self.residual_path:
            # Dimensions differ: project the input with the 1x1 convolution
            # first so that it can be added to the main path.
            shortcut = self.down_b1(self.down_c1(inputs))
        else:
            shortcut = inputs

        return self.a2(main + shortcut)
class ResNet18(Model):
    """ResNet-style classifier assembled from ``ResnetBlock`` stages.

    Layout: a 3x3 conv/BN/ReLU stem, then one stage per entry of
    ``block_list``, global average pooling, and a ``Dense`` output layer
    with no activation (raw scores).

    Args:
        block_list: Number of residual blocks in each stage, e.g.
            ``[2, 2, 2, 2]``.
        initial_filters: Filter count of the stem and of the first stage;
            the count doubles after every stage.
        num_classes: Number of units of the output layer. Defaults to 10.
    """

    def __init__(self, block_list, initial_filters=64, num_classes=10):
        super(ResNet18, self).__init__()
        self.num_blocks = len(block_list)
        self.block_list = block_list
        self.out_filters = initial_filters
        self.c1 = Conv2D(self.out_filters, (3, 3), strides=1,
                         padding='same', use_bias=False,
                         kernel_initializer='he_normal')
        self.b1 = BatchNormalization()
        self.a1 = Activation('relu')
        # Residual stages.
        self.blocks = tf.keras.models.Sequential()
        for block_id, num_layers in enumerate(block_list):
            for layer_id in range(num_layers):
                if block_id != 0 and layer_id == 0:
                    # First block of every stage but the first: dimensions
                    # change, so downsample and project the shortcut.
                    block = ResnetBlock(self.out_filters, strides=2,
                                        residual_path=True)
                else:
                    # Dimensions are unchanged: plain identity shortcut.
                    block = ResnetBlock(self.out_filters, residual_path=False)
                self.blocks.add(block)
            # The next stage uses twice as many filters.
            self.out_filters *= 2
        self.p1 = GlobalAveragePooling2D()
        self.f1 = Dense(num_classes)

    def call(self, inputs):
        """Return the output-layer scores for the input batch ``inputs``."""
        x = self.c1(inputs)
        x = self.b1(x)
        x = self.a1(x)
        x = self.blocks(x)
        x = self.p1(x)
        y = self.f1(x)
        return y
# Build the network with four stages of two residual blocks each.
model = ResNet18(block_list=[2, 2, 2, 2])