Inception module
InceptionNet 结构
code
import tensorflow as tf
import os
import numpy as np
import datetime
from tensorflow.keras.layers import Conv2D,BatchNormalization,Activation,MaxPool2D,GlobalAveragePooling2D,Dropout,Flatten,Dense
from tensorflow.keras import layers,Model
# Never truncate arrays when they are printed or converted with str().
np.set_printoptions(threshold=np.inf)

# Load the CIFAR-10 train/test splits through the Keras dataset helper.
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Rescale the images by 1/255 (presumably 8-bit pixels, giving values in
# [0, 1] -- confirm against the dataset documentation).
x_train = x_train / 255.0
x_test = x_test / 255.0
class ConvBNA(Model):
    """Convolution unit: a ReLU-activated Conv2D followed by BatchNormalization.

    NOTE(review): the ReLU is fused into the convolution, so the effective
    order is Conv -> ReLU -> BN, not Conv -> BN -> Activation as the class
    name suggests. Confirm whether this ordering is intentional.

    Args:
        ch: Number of convolution filters (output channels).
        kernelsize: Convolution kernel size.
        strides: Convolution stride.
        padding: Padding mode forwarded to Conv2D.
    """

    def __init__(self, ch, kernelsize=3, strides=1, padding='same'):
        super().__init__()
        conv = layers.Conv2D(
            filters=ch,
            kernel_size=kernelsize,
            strides=strides,
            padding=padding,
            activation='relu',
        )
        norm = layers.BatchNormalization()
        self.model = tf.keras.models.Sequential([conv, norm])

    def call(self, x):
        """Run ``x`` through the conv + batch-norm stack and return the result."""
        return self.model(x)
class InceptionBlock(Model):
    """Four-branch Inception block; branch outputs are concatenated on axis 3.

    Branches:
        1. 1x1 conv
        2. 1x1 conv -> 3x3 conv
        3. 1x1 conv -> 5x5 conv
        4. 3x3 max-pool -> 1x1 conv

    Every branch downsamples exactly once by ``strides`` and uses 'same'
    padding, so all four outputs share the same spatial size and can be
    concatenated. Each branch emits ``ch`` channels, so the block outputs
    ``4 * ch`` channels.

    Args:
        ch: Number of filters produced by each branch.
        strides: Spatial downsampling factor applied once per branch.
    """

    def __init__(self, ch, strides=1):
        super().__init__()
        self.ch = ch
        self.strides = strides

        # Branch 1: a single 1x1 conv carries the stride.
        self.c1 = ConvBNA(ch, kernelsize=1, strides=strides)

        # Branch 2: the 1x1 reduction carries the stride; the 3x3 conv must
        # use stride 1, otherwise this branch would downsample twice and no
        # longer match branch 1.
        self.c2_1 = ConvBNA(ch, kernelsize=1, strides=strides)
        self.c2_2 = ConvBNA(ch, kernelsize=3, strides=1)

        # Branch 3: same scheme with a 5x5 conv.
        self.c3_1 = ConvBNA(ch, kernelsize=1, strides=strides)
        self.c3_2 = ConvBNA(ch, kernelsize=5, strides=1)

        # Branch 4: MaxPool2D defaults to strides=pool_size and 'valid'
        # padding, which would shrink the feature map; force stride 1 with
        # 'same' padding so only the following 1x1 conv applies the stride.
        self.p4_1 = MaxPool2D(pool_size=(3, 3), strides=1, padding='same')
        self.c4_2 = ConvBNA(ch, kernelsize=1, strides=strides)

    def call(self, x):
        """Apply the four branches to ``x`` and concatenate them on axis 3.

        Assumes ``x`` is a 4-D channels-last tensor, so axis 3 is the
        channel axis -- TODO confirm against the caller.
        """
        x1 = self.c1(x)

        x2_1 = self.c2_1(x)
        x2_2 = self.c2_2(x2_1)

        x3_1 = self.c3_1(x)
        x3_2 = self.c3_2(x3_1)

        x4_1 = self.p4_1(x)
        x4_2 = self.c4_2(x4_1)

        return tf.concat([x1, x2_2, x3_2, x4_2], axis=3)
class InceptionNet(Model):
    """Small Inception-style classifier.

    Structure: one ConvBNA stem, four InceptionBlocks (two at ``init_ch``
    filters per branch, two at ``2 * init_ch``; the first block of each pair
    uses stride 2), global average pooling, then a softmax Dense layer.

    Args:
        num_classes: Number of units in the final softmax layer.
        init_ch: Filter count of the stem and of each branch in the first
            pair of Inception blocks.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``). These were
            previously accepted but silently discarded.
    """

    def __init__(self, num_classes, init_ch=16, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = init_ch
        self.input_layer = ConvBNA(init_ch)
        # Applied in order by call(); the trailing pooling layer collapses
        # the spatial dimensions before the classifier.
        self.middle_layers = [
            InceptionBlock(init_ch, strides=2),
            InceptionBlock(init_ch, strides=1),
            InceptionBlock(init_ch * 2, strides=2),
            InceptionBlock(init_ch * 2, strides=1),
            GlobalAveragePooling2D(),
        ]
        self.output_layer = Dense(num_classes, activation='softmax')

    def call(self, x):
        """Return the softmax class probabilities for the input batch ``x``."""
        x = self.input_layer(x)
        for layer in self.middle_layers:
            x = layer(x)
        y = self.output_layer(x)
        return y
# Build the network, optionally resume from a checkpoint, train for one
# epoch, then dump every trainable variable to a text file.
model = InceptionNet(num_classes=10)

# NOTE(review): the file name says "mnist" although this script loads
# CIFAR-10. The path is left unchanged so existing checkpoints are still
# found; consider renaming it.
checkpoint_save_path = "./checkpoint/mnist.ckpt"
# The '.index' file is used as the marker that a checkpoint already exists.
if os.path.exists(checkpoint_save_path + '.index'):
    print('--------------load the model----------')
    model.load_weights(checkpoint_save_path)

# Save weights only, and only when the monitored metric improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

# The model ends in a softmax layer, so the loss consumes probabilities
# (from_logits=False).
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    metrics=['sparse_categorical_accuracy'],
)

history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=1,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[cp_callback],
)

print(model.trainable_variables)

# Write name, shape and values of each trainable variable. The context
# manager closes the file even if a write fails.
with open('./weights.txt', 'w') as file:
    for v in model.trainable_variables:
        file.write(str(v.name) + '\n')
        file.write(str(v.shape) + '\n')
        # v.numpy must be *called*; without the parentheses the repr of the
        # bound method was written instead of the weight values.
        file.write(str(v.numpy()) + '\n')