import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
# Root image directories read by flow_from_directory below.
train_dir = './training/training'
valid_dir = './validation/validation'

# Training stream: pixel values scaled by 1/255 plus random geometric
# augmentation (rotation, shifts, shear, zoom, horizontal/vertical flips),
# resized to 224x224, batches of 32, reshuffled every epoch.
train_dataset = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255.,
    rotation_range=90,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='nearest',
).flow_from_directory(
    directory=train_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    shuffle=True,
)

# Validation stream: rescaling only (no augmentation) and a fixed order
# (shuffle=False) so every evaluation sees the same sequence of images.
valid_dataset = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255.,
).flow_from_directory(
    directory=valid_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    shuffle=False,
)
class MyConv(tf.keras.layers.Layer):
    """Conv2D -> BatchNormalization -> ReLU packaged as one Keras layer.

    Args:
        filters: number of output channels of the convolution.
        kernel_size: convolution kernel size, forwarded to Conv2D.
        strides: convolution stride, forwarded to Conv2D.
        padding: 'same' or 'valid', forwarded to Conv2D.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, filters, kernel_size, strides, padding, **kwargs):
        super().__init__(**kwargs)
        convolution = tf.keras.layers.Conv2D(filters=filters,
                                             kernel_size=kernel_size,
                                             strides=strides,
                                             padding=padding)
        normalization = tf.keras.layers.BatchNormalization()
        relu = tf.keras.layers.Activation(activation='relu')
        # Kept under the attribute name ``conv`` so tracked weight names
        # stay the same.
        self.conv = tf.keras.models.Sequential([convolution, normalization, relu])

    def call(self, inputs, **kwargs):
        """Apply conv, batch-norm and ReLU to ``inputs`` in sequence."""
        return self.conv(inputs)
class ResCellDown(tf.keras.layers.Layer):
    """Bottleneck residual block that downsamples with stride 2.

    Main path (``branch1``): 1x1 MyConv (exp_channel) -> 3x3 MyConv with
    stride 2 (exp_channel) -> 1x1 Conv2D (output_channel) -> BatchNorm.
    Shortcut (``branch2``): 1x1 Conv2D with stride 2 (output_channel) ->
    BatchNorm, so it matches the main path's shape. Both are summed and
    passed through ReLU.

    Args:
        exp_channel: channel count inside the bottleneck.
        output_channel: channel count produced by the block.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, exp_channel, output_channel, **kwargs):
        super().__init__(**kwargs)
        main_path = [
            MyConv(filters=exp_channel, kernel_size=1, strides=1, padding='valid'),
            MyConv(filters=exp_channel, kernel_size=3, strides=2, padding='same'),
            tf.keras.layers.Conv2D(filters=output_channel, kernel_size=1,
                                   strides=1, padding='same'),
            tf.keras.layers.BatchNormalization(),
        ]
        self.branch1 = tf.keras.models.Sequential(main_path)

        shortcut_path = [
            tf.keras.layers.Conv2D(filters=output_channel, kernel_size=1,
                                   strides=2, padding='same'),
            tf.keras.layers.BatchNormalization(),
        ]
        self.branch2 = tf.keras.models.Sequential(shortcut_path)

        self.add = tf.keras.layers.Add()
        self.activation = tf.keras.layers.Activation(activation='relu')

    def call(self, inputs, **kwargs):
        """Return ReLU(main_path(inputs) + shortcut(inputs))."""
        main_out = self.branch1(inputs)
        shortcut_out = self.branch2(inputs)
        return self.activation(self.add([main_out, shortcut_out]))
class ResCell(tf.keras.layers.Layer):
    """Bottleneck residual block with an identity shortcut (no resizing).

    Main path (``branch1``): 1x1 MyConv (exp_channel) -> 3x3 MyConv
    (exp_channel) -> 1x1 Conv2D (output_channel) -> BatchNorm, all stride 1.
    The block input itself is added to the main path, so the input must
    already have ``output_channel`` channels for the Add to succeed. The sum
    is passed through ReLU.

    Args:
        exp_channel: channel count inside the bottleneck.
        output_channel: channel count of both the input and the output.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, exp_channel, output_channel, **kwargs):
        super().__init__(**kwargs)
        reduce_conv = MyConv(filters=exp_channel, kernel_size=1, strides=1, padding='valid')
        spatial_conv = MyConv(filters=exp_channel, kernel_size=3, strides=1, padding='same')
        expand_conv = tf.keras.layers.Conv2D(filters=output_channel, kernel_size=1,
                                             strides=1, padding='same')
        expand_norm = tf.keras.layers.BatchNormalization()
        self.branch1 = tf.keras.models.Sequential(
            [reduce_conv, spatial_conv, expand_conv, expand_norm])
        self.add = tf.keras.layers.Add()
        self.activation = tf.keras.layers.Activation(activation='relu')

    def call(self, inputs, **kwargs):
        """Return ReLU(main_path(inputs) + inputs)."""
        residual = self.branch1(inputs)
        summed = self.add([residual, inputs])
        return self.activation(summed)
class BasicCellDown(tf.keras.layers.Layer):
    """Basic (two 3x3 conv) residual block that downsamples with stride 2.

    Main path (``branch1``): 3x3 MyConv with stride 2 -> 3x3 Conv2D ->
    BatchNorm. Shortcut (``branch2``): 1x1 Conv2D with stride 2 ->
    BatchNorm, projecting the input to ``output_channel`` channels at the
    reduced resolution. Both are summed and passed through ReLU.

    Args:
        output_channel: channel count produced by both paths.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, output_channel, **kwargs):
        super().__init__(**kwargs)
        # A list of layers must be wrapped in Sequential. tf.keras.models.Model
        # takes functional-API inputs/outputs as its positional arguments, so
        # it does not chain these layers.
        self.branch1 = tf.keras.models.Sequential([
            MyConv(filters=output_channel, kernel_size=3, strides=2, padding='same'),
            tf.keras.layers.Conv2D(filters=output_channel, strides=1, kernel_size=3, padding='same'),
            tf.keras.layers.BatchNormalization(),
        ])
        self.branch2 = tf.keras.models.Sequential([
            tf.keras.layers.Conv2D(filters=output_channel, strides=2, kernel_size=1, padding='same'),
            tf.keras.layers.BatchNormalization(),
        ])
        self.add = tf.keras.layers.Add()
        self.activation = tf.keras.layers.Activation(activation='relu')

    def call(self, inputs, **kwargs):
        """Return ReLU(main_path(inputs) + shortcut(inputs))."""
        branch1 = self.branch1(inputs)
        branch2 = self.branch2(inputs)
        branch = self.add([branch1, branch2])
        x = self.activation(branch)
        return x
class BasicCell(tf.keras.layers.Layer):
    """Basic (two 3x3 conv) residual block with an identity shortcut.

    Main path (``branch1``): 3x3 MyConv -> 3x3 Conv2D -> BatchNorm, all
    stride 1. The block input is added to the main path unchanged, so the
    input must already have ``output_channel`` channels for the Add to
    succeed. The sum is passed through ReLU.

    Args:
        output_channel: channel count of both the input and the output.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, output_channel, **kwargs):
        super().__init__(**kwargs)
        # A list of layers must be wrapped in Sequential. tf.keras.models.Model
        # takes functional-API inputs/outputs as its positional arguments, so
        # it does not chain these layers.
        self.branch1 = tf.keras.models.Sequential([
            MyConv(filters=output_channel, kernel_size=3, strides=1, padding='same'),
            tf.keras.layers.Conv2D(filters=output_channel, strides=1, kernel_size=3, padding='same'),
            tf.keras.layers.BatchNormalization(),
        ])
        self.add = tf.keras.layers.Add()
        self.activation = tf.keras.layers.Activation(activation='relu')

    def call(self, inputs, **kwargs):
        """Return ReLU(main_path(inputs) + inputs)."""
        branch1 = self.branch1(inputs)
        branch2 = inputs
        branch = self.add([branch1, branch2])
        x = self.activation(branch)
        return x
# Number of residual blocks in each of the four stages, keyed by network name.
res_net_structure = dict(
    res18=[2, 2, 2, 2],
    res34=[3, 4, 6, 3],
    res50=[3, 4, 6, 3],
    res101=[3, 4, 23, 3],
    res152=[3, 8, 36, 3],
)
def _bottleneck_projection_block(x, exp_channel, output_channel):
    """Apply a stride-1 bottleneck block with a 1x1 projection shortcut.

    Used where the channel count changes but the resolution does not (the
    first block of stage 0 in bottleneck ResNets: 64 stem channels -> 256).
    The layer layout mirrors ResCellDown with stride 1 everywhere.
    """
    shortcut = tf.keras.layers.Conv2D(filters=output_channel, kernel_size=1,
                                      strides=1, padding='same')(x)
    shortcut = tf.keras.layers.BatchNormalization()(shortcut)
    y = MyConv(filters=exp_channel, kernel_size=1, strides=1, padding='valid')(x)
    y = MyConv(filters=exp_channel, kernel_size=3, strides=1, padding='same')(y)
    y = tf.keras.layers.Conv2D(filters=output_channel, kernel_size=1,
                               strides=1, padding='same')(y)
    y = tf.keras.layers.BatchNormalization()(y)
    y = tf.keras.layers.Add()([y, shortcut])
    return tf.keras.layers.Activation(activation='relu')(y)


def build_res_net(net_name, num_classes, n_height=224, n_width=224):
    """Build a ResNet image classifier as a Keras functional model.

    Args:
        net_name: a key of ``res_net_structure``. 'res18' and 'res34' use
            basic (two 3x3 conv) blocks; the others use bottleneck blocks.
        num_classes: number of output classes.
        n_height, n_width: spatial size of the 3-channel float32 input.

    Returns:
        A ``tf.keras.models.Model`` whose output is a softmax over
        ``num_classes`` classes.

    Raises:
        KeyError: if ``net_name`` is not a key of ``res_net_structure``.
    """
    if net_name not in res_net_structure:
        raise KeyError(f'unknown net_name {net_name!r}; '
                       f'expected one of {sorted(res_net_structure)}')
    stage_sizes = res_net_structure[net_name]
    use_basic_blocks = net_name in ('res18', 'res34')

    input_layer = tf.keras.layers.Input(shape=(n_height, n_width, 3), dtype=tf.float32)
    # Stem: 7x7 stride-2 conv (64 channels) followed by 3x3 stride-2 max-pool.
    x = MyConv(filters=64, kernel_size=7, strides=2, padding='same')(input_layer)
    x = tf.keras.layers.MaxPool2D(pool_size=3, strides=2, padding='same')(x)

    for stage, num_blocks in enumerate(stage_sizes):
        # Width doubles per stage: 64, 128, 256, 512. Bottleneck blocks
        # output 4x that.
        width = 64 * 2 ** stage
        for block in range(num_blocks):
            first_in_stage = block == 0
            if use_basic_blocks:
                if first_in_stage and stage != 0:
                    x = BasicCellDown(width)(x)
                else:
                    # Stage 0 keeps the stem's 64 channels, so identity works.
                    x = BasicCell(width)(x)
            elif first_in_stage and stage != 0:
                x = ResCellDown(width, 4 * width)(x)
            elif first_in_stage:
                # Stem gives 64 channels but stage 0 outputs 256, so the
                # identity shortcut of ResCell cannot be used here.
                x = _bottleneck_projection_block(x, width, 4 * width)
            else:
                x = ResCell(width, 4 * width)(x)

    x = tf.keras.layers.GlobalAvgPool2D()(x)
    x = tf.keras.layers.Dense(num_classes)(x)
    prediction = tf.keras.layers.Softmax()(x)
    model = tf.keras.models.Model(inputs=input_layer, outputs=prediction)
    return model
def draw_curve(history):
    """Plot every series in ``history.history`` and display the figure.

    Each key of ``history.history`` becomes one line. The y-axis is clamped
    to [0, 1] and a grid is drawn; the figure is shown with ``plt.show()``.
    """
    frame = pd.DataFrame(history.history)
    frame.plot()
    axes = plt.gca()
    axes.set_ylim(0, 1)
    plt.grid(True)
    plt.show()
# Build and compile a ResNet-34 and a ResNet-50 for 10 classes, each with its
# own Adam optimizer (lr = 1e-3).
model34 = build_res_net('res34', 10)
model50 = build_res_net('res50', 10)
optimizer34 = tf.keras.optimizers.Adam(learning_rate=0.001)
optimizer50 = tf.keras.optimizers.Adam(learning_rate=0.001)
model34.compile(optimizer=optimizer34, loss='categorical_crossentropy', metrics=['acc'])
model50.compile(optimizer=optimizer50, loss='categorical_crossentropy', metrics=['acc'])

# Take step counts from the iterators themselves instead of `samples // 32`.
# Keras directory iterators report ceil(samples / batch_size) as their length,
# which has three effects:
#   - the batch size is not duplicated here;
#   - the final partial validation batch is evaluated instead of always dropped;
#   - the step count cannot become 0 when fewer than 32 images exist.
train_steps = len(train_dataset)
valid_steps = len(valid_dataset)

history34 = model34.fit(train_dataset,
                        steps_per_epoch=train_steps,
                        validation_data=valid_dataset,
                        validation_steps=valid_steps,
                        epochs=10)
history50 = model50.fit(train_dataset,
                        steps_per_epoch=train_steps,
                        validation_data=valid_dataset,
                        validation_steps=valid_steps,
                        epochs=10)
draw_curve(history34)
draw_curve(history50)
# RES_NET code
# (Latest recommended article published 2024-11-10 08:15:38)