本项目是一个基于ResNet50和SELayer注意力机制的海洋鱼类识别系统:以ResNet50作为主干网络提取鱼类图像特征并完成识别,在其后加入SELayer注意力机制以提高模型的识别准确率,最后使用PyQt5库搭建前端界面展示识别结果。
ResNet50代码:
from tensorflow.keras import layers, Model, Sequential
class BasicBlock(layers.Layer):
    """Residual unit made of two 3x3 convolutions (used by resnet18/resnet34).

    Computes conv -> BN -> ReLU -> conv -> BN, adds the shortcut branch and
    applies a final ReLU.

    Args:
        out_channel: Number of filters of both convolutions.
        strides: Stride of the first convolution.
        downsample: Optional callable that turns the unit input into the
            shortcut branch; when ``None`` the input itself is the shortcut.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    # Ratio between the unit's output channels and ``out_channel``.
    expansion = 1

    def __init__(self, out_channel, strides=1, downsample=None, **kwargs):
        super().__init__(**kwargs)

        # First conv + BN: the only place where ``strides`` is applied.
        self.conv1 = layers.Conv2D(
            out_channel,
            kernel_size=3,
            strides=strides,
            padding="SAME",
            use_bias=False,
        )
        self.bn1 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        # Second conv + BN: always stride 1.
        self.conv2 = layers.Conv2D(
            out_channel,
            kernel_size=3,
            strides=1,
            padding="SAME",
            use_bias=False,
        )
        self.bn2 = layers.BatchNormalization(momentum=0.9, epsilon=1e-5)

        self.downsample = downsample
        self.relu = layers.ReLU()
        self.add = layers.Add()

    def call(self, inputs, training=False):
        """Run the unit; ``training`` is passed to both batch-norm layers."""
        # The shortcut is evaluated before the main branch, as in the original.
        shortcut = inputs if self.downsample is None else self.downsample(inputs)

        out = self.relu(self.bn1(self.conv1(inputs), training=training))
        out = self.bn2(self.conv2(out), training=training)

        return self.relu(self.add([shortcut, out]))
class Bottleneck(layers.Layer):
    """Residual unit of 1x1 -> 3x3 -> 1x1 convolutions (used by resnet50/resnet101).

    The last 1x1 convolution outputs ``out_channel * expansion`` channels. The
    shortcut branch is added after the third batch norm, followed by a ReLU.

    Args:
        out_channel: Number of filters of the first two convolutions.
        strides: Stride of the 3x3 convolution.
        downsample: Optional callable that turns the unit input into the
            shortcut branch; when ``None`` the input itself is the shortcut.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    # Ratio between the unit's output channels and ``out_channel``.
    expansion = 4

    def __init__(self, out_channel, strides=1, downsample=None, **kwargs):
        super().__init__(**kwargs)

        # 1x1 convolution producing ``out_channel`` feature maps.
        self.conv1 = layers.Conv2D(
            out_channel, kernel_size=1, use_bias=False, name="conv1"
        )
        self.bn1 = layers.BatchNormalization(
            momentum=0.9, epsilon=1e-5, name="conv1/BatchNorm"
        )

        # 3x3 convolution: the only place where ``strides`` is applied.
        self.conv2 = layers.Conv2D(
            out_channel,
            kernel_size=3,
            use_bias=False,
            strides=strides,
            padding="SAME",
            name="conv2",
        )
        self.bn2 = layers.BatchNormalization(
            momentum=0.9, epsilon=1e-5, name="conv2/BatchNorm"
        )

        # 1x1 convolution expanding to ``out_channel * expansion`` channels.
        self.conv3 = layers.Conv2D(
            out_channel * self.expansion, kernel_size=1, use_bias=False, name="conv3"
        )
        self.bn3 = layers.BatchNormalization(
            momentum=0.9, epsilon=1e-5, name="conv3/BatchNorm"
        )

        self.relu = layers.ReLU()
        self.downsample = downsample
        self.add = layers.Add()

    def call(self, inputs, training=False):
        """Run the unit; ``training`` is passed to the three batch-norm layers."""
        # The shortcut is evaluated before the main branch, as in the original.
        shortcut = inputs if self.downsample is None else self.downsample(inputs)

        out = self.relu(self.bn1(self.conv1(inputs), training=training))
        out = self.relu(self.bn2(self.conv2(out), training=training))
        out = self.bn3(self.conv3(out), training=training)

        return self.relu(self.add([out, shortcut]))
def _make_layer(block, in_channel, channel, block_num, name, strides=1):
    """Assemble one ResNet stage as a named ``Sequential`` of residual units.

    Args:
        block: Residual unit class (``BasicBlock`` or ``Bottleneck``).
        in_channel: Channel count of the tensor entering the stage.
        channel: Base filter count passed to every unit.
        block_num: Number of units in the stage.
        name: Name given to the returned ``Sequential``.
        strides: Stride of the first unit; the remaining units use the
            unit's default stride.

    Returns:
        A ``Sequential`` holding ``block_num`` units named ``unit_1`` ...
        ``unit_<block_num>``.
    """
    stage_out_channel = channel * block.expansion

    # A projection shortcut (1x1 conv + BN) is only needed when the first unit
    # changes the spatial resolution or the channel count.
    shortcut = None
    if not (strides == 1 and in_channel == stage_out_channel):
        shortcut = Sequential(
            [
                layers.Conv2D(
                    stage_out_channel,
                    kernel_size=1,
                    strides=strides,
                    use_bias=False,
                    name="conv1",
                ),
                layers.BatchNormalization(
                    momentum=0.9, epsilon=1.001e-5, name="BatchNorm"
                ),
            ],
            name="shortcut",
        )

    units = [block(channel, downsample=shortcut, strides=strides, name="unit_1")]
    units.extend(
        block(channel, name="unit_" + str(position))
        for position in range(2, block_num + 1)
    )
    return Sequential(units, name=name)
def _resnet(block, blocks_num, im_width=224, im_height=224, num_classes=245, include_top=True):
    """Build a ResNet as a functional Keras ``Model``.

    Args:
        block: Residual unit class (``BasicBlock`` or ``Bottleneck``).
        blocks_num: Sequence of four ints, the number of units per stage.
        im_width: Input image width.
        im_height: Input image height.
        num_classes: Size of the classification head (used only when
            ``include_top`` is true).
        include_top: When true, append global average pooling, a dense
            ``logits`` layer and a softmax; otherwise the model outputs the
            feature map of the last stage.

    Returns:
        The assembled ``Model``.
    """
    # TensorFlow tensors are laid out as NHWC, e.g. (None, 224, 224, 3).
    input_image = layers.Input(shape=(im_height, im_width, 3), dtype="float32")

    # Stem: 7x7/2 convolution, batch norm, ReLU, 3x3/2 max pooling.
    net = layers.Conv2D(
        filters=64,
        kernel_size=7,
        strides=2,
        padding="SAME",
        use_bias=False,
        name="conv1",
    )(input_image)
    net = layers.BatchNormalization(
        momentum=0.9, epsilon=1e-5, name="conv1/BatchNorm"
    )(net)
    net = layers.ReLU()(net)
    net = layers.MaxPool2D(pool_size=3, strides=2, padding="SAME")(net)

    # Four residual stages; every stage after the first halves the resolution.
    stage_channels = (64, 128, 256, 512)
    stage_strides = (1, 2, 2, 2)
    for stage_index, (channels, stride) in enumerate(zip(stage_channels, stage_strides)):
        stage = _make_layer(
            block,
            net.shape[-1],
            channels,
            blocks_num[stage_index],
            strides=stride,
            name="block" + str(stage_index + 1),
        )
        net = stage(net)

    if not include_top:
        return Model(inputs=input_image, outputs=net)

    net = layers.GlobalAvgPool2D()(net)  # pool + flatten
    net = layers.Dense(num_classes, name="logits")(net)
    probabilities = layers.Softmax()(net)
    return Model(inputs=input_image, outputs=probabilities)
def resnet18(im_width=224, im_height=224, num_classes=245, include_top=True):
    """Build ResNet-18: ``BasicBlock`` units with stage depths [2, 2, 2, 2]."""
    return _resnet(
        block=BasicBlock,
        blocks_num=[2, 2, 2, 2],
        im_width=im_width,
        im_height=im_height,
        num_classes=num_classes,
        include_top=include_top,
    )
def resnet34(im_width=224, im_height=224, num_classes=245, include_top=True):
    """Build ResNet-34: ``BasicBlock`` units with stage depths [3, 4, 6, 3]."""
    return _resnet(
        block=BasicBlock,
        blocks_num=[3, 4, 6, 3],
        im_width=im_width,
        im_height=im_height,
        num_classes=num_classes,
        include_top=include_top,
    )
def resnet50(im_width=224, im_height=224, num_classes=245, include_top=False):
    """Build ResNet-50: ``Bottleneck`` units with stage depths [3, 4, 6, 3].

    Note that ``include_top`` defaults to ``False`` here, unlike the other
    ``resnetXX`` builders in this file, so by default the model outputs the
    feature map of the last stage.
    """
    return _resnet(
        block=Bottleneck,
        blocks_num=[3, 4, 6, 3],
        im_width=im_width,
        im_height=im_height,
        num_classes=num_classes,
        include_top=include_top,
    )
def resnet101(im_width=224, im_height=224, num_classes=245, include_top=True):
    """Build ResNet-101: ``Bottleneck`` units with stage depths [3, 4, 23, 3]."""
    return _resnet(
        block=Bottleneck,
        blocks_num=[3, 4, 23, 3],
        im_width=im_width,
        im_height=im_height,
        num_classes=num_classes,
        include_top=include_top,
    )
将上面的ResNet代码保存为model.py文件,并在训练脚本中导入。训练脚本如下,使用时主要把其中的路径(数据集路径、预训练权重路径等)改为自己的即可:
from tensorflow.keras import layers, Model, Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from time import *
from model import resnet50
import json
from attention import cbam_block
# from tensorflow.keras import layers, models, Input
# from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout
# Per-channel means (R, G, B order) subtracted from every input image.
# NOTE(review): these look like the usual ImageNet channel means - confirm
# against the preprocessing used for the pretrained weights.
_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94


def pre_function(img):
    """Return ``img`` with the per-channel mean subtracted.

    Used as the ``preprocessing_function`` of the image generators below.

    Args:
        img: Array whose last axis holds the R, G, B channels (the subtraction
            broadcasts over all leading axes).

    Returns:
        A new array of the same shape; the input is not modified.
    """
    return img - [_R_MEAN, _G_MEAN, _B_MEAN]
# Both generators apply the same mean subtraction and reserve the same 20% of
# the data for validation; only the training generator adds random horizontal
# flips.
train_image_generator = ImageDataGenerator(
    horizontal_flip=True,
    preprocessing_function=pre_function,
    validation_split=0.2,
)
validation_image_generator = ImageDataGenerator(
    preprocessing_function=pre_function,
    validation_split=0.2,
)
def data_load(data_dir, img_height, img_width, batch_size):
    """Create the training/validation iterators and save the class mapping.

    Both iterators read ``data_dir`` (one sub-directory per class) through the
    module-level image generators, which split the data into a "training" and
    a "validation" subset. The inverse mapping ``{class index: class name}``
    is written to ``models/class_indices.json``.

    Args:
        data_dir: Root directory of the dataset.
        img_height: Height images are resized to.
        img_width: Width images are resized to.
        batch_size: Number of images per batch.

    Returns:
        ``(train_ds, val_ds, class_names)`` where ``class_names`` is the
        ``{class name: class index}`` dict reported by the training iterator.
    """
    # Local import: the script's import block does not provide ``os``.
    import os

    # Options shared by both subsets.
    common_options = dict(
        class_mode='categorical',
        shuffle=True,
        seed=123,
        color_mode="rgb",
        target_size=(img_height, img_width),
        batch_size=batch_size,
    )
    train_ds = train_image_generator.flow_from_directory(
        data_dir, subset="training", **common_options
    )
    val_ds = validation_image_generator.flow_from_directory(
        data_dir, subset="validation", **common_options
    )

    class_indices = train_ds.class_indices
    # Swap keys and values so the file maps a predicted index back to a name.
    index_to_name = {index: label for label, index in class_indices.items()}

    mapping_path = 'models/class_indices.json'
    # The output directory may not exist on a fresh checkout; without this the
    # open() below raises FileNotFoundError before training even starts.
    os.makedirs(os.path.dirname(mapping_path), exist_ok=True)
    with open(mapping_path, 'w') as json_file:
        json_file.write(json.dumps(index_to_name, indent=4))

    return train_ds, val_ds, class_indices
class SELayer(tf.keras.Model):
    """Squeeze-and-excitation style channel attention.

    The input feature map is globally average-pooled, passed through a small
    fully connected gate ending in a sigmoid, and the gate is multiplied back
    onto the input.

    Args:
        filters: Number of channels of the input feature map.
        reduction: Divisor for the width of the hidden dense layer
            (``filters // reduction`` units).
    """

    def __init__(self, filters, reduction=16):
        super().__init__()
        self.filters = filters
        self.reduction = reduction

        # Squeeze: one value per channel.
        self.GAP = tf.keras.layers.GlobalAveragePooling2D()

        # Excitation: dense -> dropout -> BN -> ReLU -> dense -> dropout -> BN -> sigmoid.
        hidden_units = self.filters // self.reduction
        gate_layers = [
            tf.keras.layers.Dense(units=hidden_units, input_shape=(self.filters, )),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            tf.keras.layers.Dense(units=filters),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('sigmoid'),
        ]
        self.FC = tf.keras.models.Sequential(gate_layers)

        self.Multiply = tf.keras.layers.Multiply()

    def call(self, inputs, training=None, mask=None):
        """Return ``inputs`` scaled by the per-channel gate.

        ``training`` and ``mask`` are accepted but not forwarded explicitly.
        """
        squeezed = self.GAP(inputs)
        gate = self.FC(squeezed)
        # NOTE(review): relies on the Multiply layer broadcasting the pooled
        # gate against the full feature map - confirm for the Keras version in use.
        return self.Multiply([gate, inputs])

    def build_graph(self, input_shape):
        """Build the layer for ``input_shape`` (batch dimension included) and
        trace ``call`` once on a symbolic input of that shape."""
        per_sample_shape = input_shape[1:]
        self.build(input_shape)
        symbolic_input = tf.keras.Input(shape=per_sample_shape)
        _ = self.call(symbolic_input)
def model_load(IMG_SHAPE=(224, 224, 3), class_num=6):
    """Build and compile the classifier: frozen ResNet-50 + SELayer + dense head.

    The ResNet-50 feature extractor is created without its top, frozen, and
    initialised from ``./pretrain_weights.ckpt``. On top of it come the
    ``SELayer`` attention block, global average pooling, a 1024-unit dense
    layer, dropout and a softmax layer with ``class_num`` outputs.

    Args:
        IMG_SHAPE: ``(height, width, channels)`` of the input images. Height
            and width are now passed to the backbone instead of being ignored
            in favour of a hard-coded 224x224; the backbone always expects 3
            channels.
        class_num: Number of output classes.

    Returns:
        The compiled ``tf.keras.Sequential`` model (SGD optimizer,
        categorical cross-entropy loss, accuracy metric).
    """
    img_height, img_width = IMG_SHAPE[0], IMG_SHAPE[1]

    feature = resnet50(
        im_width=img_width,
        im_height=img_height,
        num_classes=class_num,
        include_top=False,
    )
    # Only the layers stacked on top of the backbone are trained.
    feature.trainable = False
    pre_weights_path = './pretrain_weights.ckpt'
    feature.load_weights(pre_weights_path)

    model = tf.keras.Sequential([
        feature,
        # 2048 = 512 * Bottleneck.expansion, the channel count of the last stage.
        SELayer(2048),
        tf.keras.layers.GlobalAveragePooling2D(),
        # NOTE(review): this dense layer has no activation, so it is linear;
        # left unchanged to stay compatible with already-trained weights.
        tf.keras.layers.Dense(1024),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(class_num, activation='softmax'),
    ])
    model.summary()
    model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
    return model
def show_loss_acc(history):
    """Plot the training/validation accuracy and loss curves and save them.

    Args:
        history: Object whose ``history`` dict holds the per-epoch series
            ``accuracy``, ``val_accuracy``, ``loss`` and ``val_loss`` (the
            value returned by ``model.fit``).

    The figure is written to ``results_reset_last2_fish6.png``.
    """
    curves = history.history
    # Read all four series up front so a missing key fails before plotting.
    train_acc, valid_acc, train_loss, valid_loss = (
        curves[key] for key in ('accuracy', 'val_accuracy', 'loss', 'val_loss')
    )

    plt.figure(figsize=(8, 8))

    # Upper panel: accuracy.
    plt.subplot(2, 1, 1)
    plt.plot(train_acc, label='Training Accuracy')
    plt.plot(valid_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    lower_bound = min(plt.ylim())
    plt.ylim([lower_bound, 1])  # pin the top of the axis at 1
    plt.title('Training and Validation Accuracy')

    # Lower panel: loss.
    plt.subplot(2, 1, 2)
    plt.plot(train_loss, label='Training Loss')
    plt.plot(valid_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Cross Entropy')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')

    plt.savefig('results_reset_last2_fish6.png', dpi=100)
def train(epochs):
    """Train the classifier, save its weights and plot the learning curves.

    Args:
        epochs: Number of training epochs.

    The dataset directory and the output weight path are hard-coded below;
    the weights are saved in TensorFlow checkpoint format.
    """
    started_at = time()

    train_ds, val_ds, class_names = data_load(
        data_dir="C:/Users/lenovo/Desktop/fish_image23/fish_image",
        img_height=224,
        img_width=224,
        batch_size=16,
    )
    print(class_names)

    model = model_load(class_num=len(class_names))
    history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)
    model.save_weights("models/resnet_last2_fish6_epoch5.ckpt", save_format='tf')

    # Elapsed wall-clock time covers data loading, training and saving.
    elapsed = time() - started_at
    print('该循环程序运行时间:', elapsed, "s")

    show_loss_acc(history)
if __name__ == "__main__":
    # Script entry point: run a 5-epoch training session.
    train(5)
训练完成后,使用PyQt5编写前端界面,加载训练得到的模型权重并展示识别结果,代码如下:
import tensorflow as tf
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import sys
import cv2
from PIL import Image
import numpy as np
import shutil
from model import resnet50
import glob
class MainWindow(QTabWidget):
    """Main window of the marine fish classification GUI.

    The window has two tabs: the main page (image preview on the left, upload
    and predict buttons plus the result label on the right) and an "about"
    page. The classifier is built and its weights are loaded on the first
    prediction, then reused for later predictions.
    """

    # Working files written while preparing an image.
    _SHOW_IMAGE = "images/show.png"      # preview scaled for display
    _TARGET_IMAGE = "images/target.png"  # 224x224 copy fed to the network
    _UPLOAD_COPY = "images/tmpx.jpg"     # local copy of the uploaded file
    _WEIGHTS_PATH = 'models/resnet_last2_fish6_epoch5.ckpt'

    _PREVIEW_HEIGHT = 400
    _INPUT_SIZE = (224, 224)
    # Per-channel means (R, G, B) subtracted before prediction.
    _CHANNEL_MEANS = (123.68, 116.78, 103.94)

    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon('images/img.png'))
        self.setWindowTitle('海洋鱼类分类系统')
        self.to_predict_name = "images/fish1.png"
        self.class_names = ['Amphiprion clarkia', 'Chaetodon lunulatus', 'Chaetodon trifascialis', 'Chromis chrysura', 'Dascyllus reticulatus', 'Pletrogly-phidodon dickii']
        # Built lazily by predict_img so the window opens quickly.
        self._model = None
        # True once a network input image has been written successfully.
        self._has_target = False
        self.resize(900, 700)
        self.initUI()

    def initUI(self):
        """Create the widgets of both tabs and show the default sample image."""
        main_widget = QWidget()
        main_layout = QHBoxLayout()
        font = QFont('楷体', 15)

        # Left side: sample image preview.
        left_widget = QWidget()
        left_layout = QVBoxLayout()
        img_title = QLabel("样本")
        img_title.setFont(font)
        img_title.setAlignment(Qt.AlignCenter)
        self.img_label = QLabel()
        # A missing or unreadable default image leaves the preview empty
        # instead of crashing the window during construction.
        self._prepare_preview(self.to_predict_name)
        left_layout.addWidget(img_title)
        left_layout.addWidget(self.img_label, 1, Qt.AlignCenter)
        left_widget.setLayout(left_layout)

        # Right side: buttons and result label.
        right_widget = QWidget()
        right_layout = QVBoxLayout()
        btn_change = QPushButton(" 上传图片 ")
        btn_change.clicked.connect(self.change_img)
        btn_change.setFont(font)
        btn_predict = QPushButton(" 开始识别 ")
        btn_predict.setFont(font)
        btn_predict.clicked.connect(self.predict_img)
        label_result = QLabel(' 海洋鱼类名称 ')
        self.result = QLabel("等待识别")
        label_result.setFont(QFont('楷体', 24))
        self.result.setFont(QFont('楷体', 24))
        right_layout.addWidget(label_result, 0, Qt.AlignCenter)
        right_layout.addStretch(0)
        right_layout.addWidget(self.result, 0, Qt.AlignCenter)
        right_layout.addWidget(btn_change)
        right_layout.addWidget(btn_predict)
        right_layout.addStretch()
        right_widget.setLayout(right_layout)

        # "About" page.
        about_widget = QWidget()
        about_layout = QVBoxLayout()
        about_title = QLabel('欢迎使用海洋鱼类分类系统')
        about_title.setFont(QFont('楷体', 18))
        about_title.setAlignment(Qt.AlignCenter)
        about_img = QLabel()
        about_img.setPixmap(QPixmap('images/about1.png'))
        about_img.setAlignment(Qt.AlignCenter)
        label_super = QLabel('<a href="https://blog.csdn.net/qq_46136833?type=blog">鱼类识别设计</a>')
        label_super.setFont(QFont('楷体', 12))
        label_super.setOpenExternalLinks(True)
        label_super.setAlignment(Qt.AlignRight)
        about_layout.addWidget(about_title)
        about_layout.addStretch()
        about_layout.addWidget(about_img)
        about_layout.addStretch()
        about_layout.addWidget(label_super)
        about_widget.setLayout(about_layout)

        main_layout.addWidget(left_widget)
        main_layout.addWidget(right_widget)
        main_widget.setLayout(main_layout)
        self.addTab(main_widget, '主页')
        self.addTab(about_widget, '关于')
        self.setTabIcon(0, QIcon('images/主页面.png'))
        self.setTabIcon(1, QIcon('images/关于.png'))

    def _prepare_preview(self, image_path):
        """Write the preview and network-input copies of ``image_path``.

        Returns:
            True when the image could be read and the preview label was
            updated, False when OpenCV could not read the file.
        """
        image = cv2.imread(image_path)
        if image is None:  # cv2.imread reports failure by returning None
            return False

        scale = self._PREVIEW_HEIGHT / image.shape[0]
        preview = cv2.resize(image, (0, 0), fx=scale, fy=scale)
        cv2.imwrite(self._SHOW_IMAGE, preview)
        cv2.imwrite(self._TARGET_IMAGE, cv2.resize(image, self._INPUT_SIZE))

        self.img_label.setPixmap(QPixmap(self._SHOW_IMAGE))
        self._has_target = True
        return True

    def change_img(self):
        """Let the user pick an image file and make it the current sample."""
        img_name, _ = QFileDialog.getOpenFileName(self, 'chose files', '', 'Image files(*.jpg *.png *jpeg)')
        if not img_name:  # dialog cancelled
            return

        # The file is copied to a fixed working path before OpenCV reads it.
        try:
            shutil.copy(img_name, self._UPLOAD_COPY)
        except shutil.SameFileError:
            pass  # the user picked the working copy itself; nothing to copy
        except OSError as err:
            QMessageBox.warning(self, '错误', '无法复制所选图片:{}'.format(err))
            return

        if not self._prepare_preview(self._UPLOAD_COPY):
            QMessageBox.warning(self, '错误', '无法读取所选图片')
            return
        self.to_predict_name = self._UPLOAD_COPY

    def _build_model(self):
        """Build the classifier and load the trained weights.

        The architecture mirrors the training script: frozen ResNet-50
        features, ``SELayer`` attention, global average pooling, a 1024-unit
        dense layer, dropout and a softmax layer with one output per class.
        """
        num_classes = len(self.class_names)

        class SELayer(tf.keras.Model):
            """Channel attention block; attribute names are kept so the saved
            checkpoint still matches."""

            def __init__(self, filters, reduction=16):
                super(SELayer, self).__init__()
                self.filters = filters
                self.reduction = reduction
                self.GAP = tf.keras.layers.GlobalAveragePooling2D()
                self.FC = tf.keras.models.Sequential([
                    tf.keras.layers.Dense(units=self.filters // self.reduction, input_shape=(self.filters,)),
                    tf.keras.layers.Dropout(0.5),
                    tf.keras.layers.BatchNormalization(),
                    tf.keras.layers.Activation('relu'),
                    tf.keras.layers.Dense(units=filters),
                    tf.keras.layers.Dropout(0.5),
                    tf.keras.layers.BatchNormalization(),
                    tf.keras.layers.Activation('sigmoid')
                ])
                self.Multiply = tf.keras.layers.Multiply()

            def call(self, inputs, training=None, mask=None):
                x = self.GAP(inputs)
                x = self.FC(x)
                x = self.Multiply([x, inputs])
                return x

            def build_graph(self, input_shape):
                input_shape_without_batch = input_shape[1:]
                self.build(input_shape)
                inputs = tf.keras.Input(shape=input_shape_without_batch)
                _ = self.call(inputs)

        feature = resnet50(num_classes=num_classes, include_top=False)
        feature.trainable = False
        model = tf.keras.Sequential([
            feature,
            SELayer(2048),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(1024),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(num_classes, activation='softmax'),
        ])
        model.load_weights(self._WEIGHTS_PATH)
        return model

    def predict_img(self):
        """Classify the current sample image and show the predicted name."""
        if not self._has_target:
            QMessageBox.warning(self, '提示', '请先上传图片')
            return

        if self._model is None:
            # Explicit check instead of ``assert``: asserts are stripped under
            # ``python -O`` and an uncaught error in a slot can abort the app.
            if not glob.glob(self._WEIGHTS_PATH + "*"):
                QMessageBox.warning(self, '错误', "cannot find {}".format(self._WEIGHTS_PATH))
                return
            # Building the network and loading weights is slow, so do it once.
            self._model = self._build_model()

        try:
            with Image.open(self._TARGET_IMAGE) as picture:
                pixels = np.array(picture.convert("RGB")).astype(np.float32)
        except OSError as err:
            QMessageBox.warning(self, '错误', '无法读取待识别图片:{}'.format(err))
            return

        pixels = pixels - np.asarray(self._CHANNEL_MEANS)
        batch = pixels.reshape(1, self._INPUT_SIZE[1], self._INPUT_SIZE[0], 3)
        outputs = np.squeeze(self._model.predict(batch))

        result_index = int(np.argmax(outputs))
        result = self.class_names[result_index]
        print(result_index, result)
        self.result.setText(result)

    def closeEvent(self, event):
        """Ask for confirmation before the window closes."""
        reply = QMessageBox.question(self,
                                     '退出',
                                     "是否要退出程序?",
                                     QMessageBox.Yes | QMessageBox.No,
                                     QMessageBox.No)
        # Accepting the event is what closes the window; calling close() again
        # from inside closeEvent is unnecessary.
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()
if __name__ == "__main__":
    # Create the Qt application, show the main window and run the event loop;
    # the loop's return code becomes the process exit status.
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)
效果图:
鱼类模型及其预处理文件,以及前端显示所需的相关文件,均已上传到资源中,可以自行下载。