基于ResNet50以及SELayer注意力机制的迁移学习海洋鱼类识别系统及其前端展示

基于 ResNet50 与 SELayer(Squeeze-and-Excitation)注意力机制的海洋鱼类识别系统:以加载了预训练权重的 ResNet50 作为特征提取网络进行迁移学习,在其输出特征图上加入 SELayer 通道注意力以提高识别准确率,最后使用 PyQt5 实现前端界面来展示识别结果。

ResNet 模型代码(包含 ResNet18/34/50/101 的构建函数):

from tensorflow.keras import layers, Model, Sequential



class BasicBlock(layers.Layer):
    """Two-convolution (3x3 -> 3x3) residual unit used by ResNet-18/34.

    Args:
        out_channel: number of filters in both 3x3 convolutions.
        strides: stride of the first convolution (e.g. 2 to downsample).
        downsample: optional layer applied to the input to build the shortcut
            when the input and residual-branch shapes differ; None means an
            identity shortcut.
    """

    # Output channels = out_channel * expansion (no expansion for this unit).
    expansion = 1

    def __init__(self, out_channel, strides=1, downsample=None, **kwargs):
        super().__init__(**kwargs)
        bn_args = {"momentum": 0.9, "epsilon": 1e-5}
        # Attribute names and layer creation order are kept stable so saved
        # checkpoints keep mapping onto this block.
        self.conv1 = layers.Conv2D(out_channel, kernel_size=3, strides=strides,
                                   padding="SAME", use_bias=False)
        self.bn1 = layers.BatchNormalization(**bn_args)
        self.conv2 = layers.Conv2D(out_channel, kernel_size=3, strides=1,
                                   padding="SAME", use_bias=False)
        self.bn2 = layers.BatchNormalization(**bn_args)
        self.downsample = downsample
        self.relu = layers.ReLU()
        self.add = layers.Add()

    def call(self, inputs, training=False):
        shortcut = inputs if self.downsample is None else self.downsample(inputs)

        out = self.relu(self.bn1(self.conv1(inputs), training=training))
        out = self.bn2(self.conv2(out), training=training)

        # Merge with the shortcut first, then apply the final ReLU.
        return self.relu(self.add([shortcut, out]))


class Bottleneck(layers.Layer):
    """1x1 -> 3x3 -> 1x1 residual unit used by ResNet-50/101.

    Args:
        out_channel: filters of the two inner convolutions; the last 1x1
            convolution outputs out_channel * expansion channels.
        strides: stride of the 3x3 convolution (e.g. 2 to downsample).
        downsample: optional layer producing the shortcut when shapes differ;
            None means an identity shortcut.
    """

    # The final 1x1 convolution widens the channels by this factor.
    expansion = 4

    def __init__(self, out_channel, strides=1, downsample=None, **kwargs):
        super().__init__(**kwargs)
        bn_args = {"momentum": 0.9, "epsilon": 1e-5}
        # Explicit layer names, attribute names and creation order are kept
        # unchanged because pretrained checkpoints are restored onto them.
        self.conv1 = layers.Conv2D(out_channel, kernel_size=1, use_bias=False, name="conv1")
        self.bn1 = layers.BatchNormalization(name="conv1/BatchNorm", **bn_args)
        self.conv2 = layers.Conv2D(out_channel, kernel_size=3, use_bias=False,
                                   strides=strides, padding="SAME", name="conv2")
        self.bn2 = layers.BatchNormalization(name="conv2/BatchNorm", **bn_args)
        self.conv3 = layers.Conv2D(out_channel * self.expansion, kernel_size=1,
                                   use_bias=False, name="conv3")
        self.bn3 = layers.BatchNormalization(name="conv3/BatchNorm", **bn_args)
        self.relu = layers.ReLU()
        self.downsample = downsample
        self.add = layers.Add()

    def call(self, inputs, training=False):
        shortcut = inputs if self.downsample is None else self.downsample(inputs)

        # conv -> BN -> ReLU for the first two convolutions.
        out = inputs
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out), training=training))

        # The last convolution is only normalized; ReLU comes after the merge.
        out = self.bn3(self.conv3(out), training=training)
        return self.relu(self.add([out, shortcut]))


def _make_layer(block, in_channel, channel, block_num, name, strides=1):
    """Stack `block_num` residual units into one named ResNet stage.

    Only the first unit may change the stride/channel count. When it does, a
    1x1 conv + BN projection shortcut is attached to that unit.

    Args:
        block: residual unit class (BasicBlock or Bottleneck).
        in_channel: channel count of the stage input.
        channel: base channel count; units output channel * block.expansion.
        block_num: number of units in the stage.
        name: name of the returned Sequential.
        strides: stride of the first unit.

    Returns:
        A tf.keras Sequential containing units named "unit_1" .. "unit_N".
    """
    out_channel = channel * block.expansion

    shortcut = None
    if strides != 1 or in_channel != out_channel:
        shortcut = Sequential([
            layers.Conv2D(out_channel, kernel_size=1, strides=strides,
                          use_bias=False, name="conv1"),
            layers.BatchNormalization(momentum=0.9, epsilon=1.001e-5, name="BatchNorm")
        ], name="shortcut")

    units = [block(channel, downsample=shortcut, strides=strides, name="unit_1")]
    units.extend(block(channel, name="unit_{}".format(i)) for i in range(2, block_num + 1))

    return Sequential(units, name=name)


def _resnet(block, blocks_num, im_width=224, im_height=224, num_classes=245, include_top=True):
    """Assemble a ResNet functional model.

    Args:
        block: residual unit class (BasicBlock or Bottleneck).
        blocks_num: number of units in each of the four stages.
        im_width, im_height: spatial size of the 3-channel float32 input.
        num_classes: size of the classification head (used only if include_top).
        include_top: if True, append global average pooling, a Dense "logits"
            layer and a Softmax; otherwise output the last stage's feature map.

    Returns:
        A tf.keras Model with input shape (None, im_height, im_width, 3).
    """
    # TensorFlow tensors are laid out NHWC, e.g. (None, 224, 224, 3).
    input_image = layers.Input(shape=(im_height, im_width, 3), dtype="float32")

    # Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max pool.
    x = layers.Conv2D(filters=64, kernel_size=7, strides=2,
                      padding="SAME", use_bias=False, name="conv1")(input_image)
    x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name="conv1/BatchNorm")(x)
    x = layers.ReLU()(x)
    x = layers.MaxPool2D(pool_size=3, strides=2, padding="SAME")(x)

    # Four stages; every stage after the first starts with a stride-2 unit.
    stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
    for stage_idx, (channels, stride) in enumerate(stage_specs):
        stage = _make_layer(block, x.shape[-1], channels, blocks_num[stage_idx],
                            strides=stride, name="block{}".format(stage_idx + 1))
        x = stage(x)

    if not include_top:
        return Model(inputs=input_image, outputs=x)

    x = layers.GlobalAvgPool2D()(x)  # pool + flatten
    x = layers.Dense(num_classes, name="logits")(x)
    predict = layers.Softmax()(x)
    return Model(inputs=input_image, outputs=predict)


def resnet18(im_width=224, im_height=224, num_classes=245, include_top=True):
    """Build ResNet-18: BasicBlock units, stage depths [2, 2, 2, 2]."""
    stage_depths = [2, 2, 2, 2]
    return _resnet(BasicBlock, stage_depths, im_width=im_width, im_height=im_height,
                   num_classes=num_classes, include_top=include_top)

def resnet34(im_width=224, im_height=224, num_classes=245, include_top=True):
    """Build ResNet-34: BasicBlock units, stage depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return _resnet(BasicBlock, stage_depths, im_width=im_width, im_height=im_height,
                   num_classes=num_classes, include_top=include_top)


def resnet50(im_width=224, im_height=224, num_classes=245, include_top=False):
    """Build ResNet-50: Bottleneck units, stage depths [3, 4, 6, 3].

    Note: unlike resnet18/34/101, include_top defaults to False here, so the
    default call returns the headless feature extractor.
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet(Bottleneck, stage_depths, im_width=im_width, im_height=im_height,
                   num_classes=num_classes, include_top=include_top)


def resnet101(im_width=224, im_height=224, num_classes=245, include_top=True):
    """Build ResNet-101: Bottleneck units, stage depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return _resnet(Bottleneck, stage_depths, im_width=im_width, im_height=im_height,
                   num_classes=num_classes, include_top=include_top)

将上面的代码保存为 model.py 文件,并在训练脚本中导入。训练代码如下,使用时主要把数据集路径和预训练权重路径改为自己的即可。注意:脚本中的 `from attention import cbam_block` 并未被使用,若没有 attention.py 文件,可直接删除该行。

from tensorflow.keras import layers, Model, Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from time import *
from model import resnet50
import json
from attention import cbam_block
# from tensorflow.keras import layers, models, Input
# from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout

# Per-channel RGB means subtracted from every image by pre_function.
# NOTE(review): these look like the ImageNet channel means used by many
# pretrained backbones — confirm they match pretrain_weights.ckpt.
_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94


def pre_function(img):
    """Center an RGB image by subtracting the per-channel means.

    Used as the ImageDataGenerator `preprocessing_function`, so `img` is the
    array the generator passes in, with channels last in R, G, B order.
    """
    channel_means = [_R_MEAN, _G_MEAN, _B_MEAN]
    return img - channel_means
# Both generators use the same validation_split, so the "training" and
# "validation" subsets requested in data_load come from one shared split.
# Only the training generator applies random horizontal flips.
train_image_generator = ImageDataGenerator(
    preprocessing_function=pre_function,
    horizontal_flip=True,
    validation_split=0.2,
)
validation_image_generator = ImageDataGenerator(
    preprocessing_function=pre_function,
    validation_split=0.2,
)
def data_load(data_dir, img_height, img_width, batch_size):
    """Create training/validation iterators from a class-per-folder dataset.

    Also writes the index -> class-name mapping to models/class_indices.json.
    The `models` directory is created if needed.

    Args:
        data_dir: dataset root, one sub-directory per class.
        img_height, img_width: size every image is resized to.
        batch_size: images per batch.

    Returns:
        (train_ds, val_ds, class_names), where class_names is the
        class-name -> index dict produced by flow_from_directory.
    """
    import os

    # Options shared by both subsets; only `subset` differs.
    flow_args = dict(
        class_mode='categorical',
        shuffle=True,
        seed=123,
        color_mode="rgb",
        target_size=(img_height, img_width),
        batch_size=batch_size,
    )
    train_ds = train_image_generator.flow_from_directory(data_dir, subset="training", **flow_args)
    val_ds = validation_image_generator.flow_from_directory(data_dir, subset="validation", **flow_args)

    class_names = train_ds.class_indices

    # Invert to index -> name so predictions can be decoded later.
    inverse_dict = {index: name for name, index in class_names.items()}
    # Without this, the first run fails because models/ does not exist yet.
    os.makedirs('models', exist_ok=True)
    with open('models/class_indices.json', 'w', encoding='utf-8') as json_file:
        json.dump(inverse_dict, json_file, indent=4)

    return train_ds, val_ds, class_names


class SELayer(tf.keras.Model):
    """Channel-attention block in the spirit of Squeeze-and-Excitation.

    The feature map is global-average-pooled to one value per channel. That
    vector passes through a two-stage gating MLP ending in a sigmoid, and the
    per-channel gates are multiplied back onto the input feature map.

    Args:
        filters: channel count of the incoming feature map (2048 for the
            ResNet50 backbone used in model_load).
        reduction: factor by which the first Dense layer shrinks the channels.
    """

    def __init__(self, filters, reduction=16):
        super().__init__()
        self.filters = filters
        self.reduction = reduction
        hidden_units = self.filters // self.reduction
        # NOTE: attribute names GAP / FC / Multiply and this exact layer stack
        # are part of the saved checkpoint. The GUI rebuilds the same stack to
        # restore the weights, so keep them unchanged.
        # NOTE(review): Dropout/BatchNorm inside the gating MLP is not part of
        # the standard SE design.
        self.GAP = tf.keras.layers.GlobalAveragePooling2D()
        self.FC = tf.keras.models.Sequential([
            # squeeze: filters -> filters // reduction
            tf.keras.layers.Dense(units=hidden_units, input_shape=(self.filters, )),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('relu'),
            # excite: back to one sigmoid gate per channel
            tf.keras.layers.Dense(units=filters),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Activation('sigmoid')
        ])
        self.Multiply = tf.keras.layers.Multiply()

    def call(self, inputs, training=None, mask=None):
        gates = self.FC(self.GAP(inputs))  # (batch, filters)
        # NOTE(review): relies on Keras' Multiply aligning the rank-2 gates
        # with the rank-4 feature map — confirm on the TF version in use.
        return self.Multiply([gates, inputs])

    def build_graph(self, input_shape):
        """Build for `input_shape` (batch dim included), then trace once on a symbolic Input."""
        self.build(input_shape)
        symbolic_input = tf.keras.Input(shape=input_shape[1:])
        self.call(symbolic_input)

def model_load(IMG_SHAPE=(224, 224, 3), class_num=6, pre_weights_path='./pretrain_weights.ckpt'):
    """Build and compile the ResNet50 + SELayer transfer-learning classifier.

    Args:
        IMG_SHAPE: (height, width, channels) of the input images. Height and
            width size the backbone input (previously ignored in favor of a
            hard-coded 224x224); the backbone always takes 3 channels.
        class_num: number of output classes.
        pre_weights_path: checkpoint holding the pretrained backbone weights.

    Returns:
        A tf.keras.Sequential model compiled with SGD and categorical
        cross-entropy.
    """
    im_height, im_width = IMG_SHAPE[0], IMG_SHAPE[1]
    # Frozen backbone: only the SELayer and the new head are trained.
    feature = resnet50(im_width=im_width, im_height=im_height,
                       num_classes=class_num, include_top=False)
    feature.trainable = False
    feature.load_weights(pre_weights_path)
    model = tf.keras.Sequential([
        feature,
        SELayer(2048),
        tf.keras.layers.GlobalAveragePooling2D(),
        # NOTE(review): this Dense has no activation, so it composes linearly
        # with the final Dense. It is kept as-is because the GUI restores
        # weights into this exact stack.
        tf.keras.layers.Dense(1024),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(class_num, activation='softmax'),
    ])
    model.summary()
    model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
    return model
def show_loss_acc(history):
    """Plot training/validation accuracy and loss curves and save them as a PNG.

    Args:
        history: object returned by `model.fit`; its `.history` dict must hold
            'accuracy', 'val_accuracy', 'loss' and 'val_loss'.
    """
    curves = history.history
    train_acc, valid_acc = curves['accuracy'], curves['val_accuracy']
    train_loss, valid_loss = curves['loss'], curves['val_loss']

    plt.figure(figsize=(8, 8))

    # Top panel: accuracy.
    plt.subplot(2, 1, 1)
    plt.plot(train_acc, label='Training Accuracy')
    plt.plot(valid_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    lower_y = min(plt.ylim())
    plt.ylim([lower_y, 1])
    plt.title('Training and Validation Accuracy')

    # Bottom panel: loss.
    plt.subplot(2, 1, 2)
    plt.plot(train_loss, label='Training Loss')
    plt.plot(valid_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Cross Entropy')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')

    plt.savefig('results_reset_last2_fish6.png', dpi=100)

def train(epochs, data_dir="C:/Users/lenovo/Desktop/fish_image23/fish_image", batch_size=16):
    """Train the classifier, save its weights and plot the learning curves.

    Args:
        epochs: number of training epochs.
        data_dir: dataset root with one sub-directory per class. The
            module-level generators split it 80/20 into training/validation.
            Defaults to the previously hard-coded path.
        batch_size: images per batch (previously hard-coded to 16).
    """
    begin_time = time()
    train_ds, val_ds, class_names = data_load(data_dir, 224, 224, batch_size)
    print(class_names)
    model = model_load(class_num=len(class_names))
    history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)
    # The GUI loads weights from this exact path; keep both in sync if changed.
    model.save_weights("models/resnet_last2_fish6_epoch5.ckpt", save_format='tf')
    run_time = time() - begin_time
    print('该循环程序运行时间:', run_time, "s")
    show_loss_acc(history)

if __name__ == "__main__":
    # Script entry point: run a 5-epoch training session.
    train(epochs=5)

训练完成后,使用 PyQt5 前端加载保存的模型权重并展示识别结果,代码如下:

import tensorflow as tf
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import sys
import cv2
from PIL import Image
import numpy as np
import shutil
from model import resnet50
import glob


class MainWindow(QTabWidget):
    """Main window of the marine fish classification GUI.

    The "主页" tab shows the current sample image, with buttons to upload an
    image and run recognition. The "关于" tab shows an about page.
    """

    # Checkpoint written by the training script's train().
    WEIGHTS_PATH = 'models/resnet_last2_fish6_epoch5.ckpt'
    # Per-channel RGB means; must match pre_function used during training.
    CHANNEL_MEANS = (123.68, 116.78, 103.94)

    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon('images/img.png'))
        self.setWindowTitle('海洋鱼类分类系统')
        self.to_predict_name = "images/fish1.png"
        # Entry i is the name displayed for model output i. The order must
        # match the class indices produced by flow_from_directory during
        # training (saved to models/class_indices.json).
        # NOTE(review): 'Amphiprion clarkia' / 'Pletrogly-phidodon dickii'
        # differ from the usual spellings (clarkii, Plectroglyphidodon) —
        # confirm against the dataset folder names.
        self.class_names = ['Amphiprion clarkia', 'Chaetodon lunulatus', 'Chaetodon trifascialis', 'Chromis chrysura', 'Dascyllus reticulatus', 'Pletrogly-phidodon dickii']
        # The model is built lazily on the first prediction and then reused,
        # instead of rebuilding ResNet50 and reloading weights on every click.
        self._model = None
        self.resize(900, 700)
        self.initUI()

    def initUI(self):
        """Create the main (image + controls) tab and the about tab."""
        main_widget = QWidget()
        main_layout = QHBoxLayout()
        font = QFont('楷体', 15)

        # Left side: sample image.
        left_widget = QWidget()
        left_layout = QVBoxLayout()
        img_title = QLabel("样本")
        img_title.setFont(font)
        img_title.setAlignment(Qt.AlignCenter)
        self.img_label = QLabel()
        self._show_image(self.to_predict_name)
        left_layout.addWidget(img_title)
        left_layout.addWidget(self.img_label, 1, Qt.AlignCenter)
        left_widget.setLayout(left_layout)

        # Right side: result label and action buttons.
        right_widget = QWidget()
        right_layout = QVBoxLayout()
        btn_change = QPushButton(" 上传图片 ")
        btn_change.clicked.connect(self.change_img)
        btn_change.setFont(font)
        btn_predict = QPushButton(" 开始识别 ")
        btn_predict.setFont(font)
        btn_predict.clicked.connect(self.predict_img)

        label_result = QLabel(' 海洋鱼类名称 ')
        self.result = QLabel("等待识别")
        label_result.setFont(QFont('楷体', 24))
        self.result.setFont(QFont('楷体', 24))

        right_layout.addWidget(label_result, 0, Qt.AlignCenter)
        right_layout.addStretch(0)
        right_layout.addWidget(self.result, 0, Qt.AlignCenter)
        right_layout.addWidget(btn_change)
        right_layout.addWidget(btn_predict)
        right_layout.addStretch()
        right_widget.setLayout(right_layout)

        # About page.
        about_widget = QWidget()
        about_layout = QVBoxLayout()
        about_title = QLabel('欢迎使用海洋鱼类分类系统')
        about_title.setFont(QFont('楷体', 18))
        about_title.setAlignment(Qt.AlignCenter)
        about_img = QLabel()
        about_img.setPixmap(QPixmap('images/about1.png'))
        about_img.setAlignment(Qt.AlignCenter)
        label_super = QLabel('<a href="https://blog.csdn.net/qq_46136833?type=blog">鱼类识别设计</a>')
        label_super.setFont(QFont('楷体', 12))
        label_super.setOpenExternalLinks(True)
        label_super.setAlignment(Qt.AlignRight)
        about_layout.addWidget(about_title)
        about_layout.addStretch()
        about_layout.addWidget(about_img)
        about_layout.addStretch()
        about_layout.addWidget(label_super)
        about_widget.setLayout(about_layout)

        main_layout.addWidget(left_widget)
        main_layout.addWidget(right_widget)
        main_widget.setLayout(main_layout)
        self.addTab(main_widget, '主页')
        self.addTab(about_widget, '关于')
        self.setTabIcon(0, QIcon('images/主页面.png'))
        self.setTabIcon(1, QIcon('images/关于.png'))

    def _show_image(self, path):
        """Display `path` in the sample label and write the 224x224 model input.

        Writes a 400-px-high preview to images/show.png and a 224x224 copy to
        images/target.png, which predict_img reads.

        Returns:
            True on success, or False if OpenCV could not decode the image.
        """
        img = cv2.imread(path)
        if img is None:
            return False
        scale = 400 / img.shape[0]
        img_show = cv2.resize(img, (0, 0), fx=scale, fy=scale)
        cv2.imwrite("images/show.png", img_show)
        cv2.imwrite('images/target.png', cv2.resize(img, (224, 224)))
        self.img_label.setPixmap(QPixmap("images/show.png"))
        return True

    def change_img(self):
        """Let the user pick an image, copy it to images/tmpx.jpg and display it."""
        img_name, _ = QFileDialog.getOpenFileName(self, 'chose files', '', 'Image files(*.jpg *.png *.jpeg)')
        if not img_name:
            return  # dialog cancelled
        target_image_name = "images/tmpx.jpg"
        shutil.copy(img_name, target_image_name)
        if not self._show_image(target_image_name):
            QMessageBox.warning(self, '错误', '无法读取该图片文件')
            return
        self.to_predict_name = target_image_name

    def _build_model(self):
        """Build the ResNet50 + SELayer classifier and restore the trained weights.

        The architecture must match model_load() in the training script exactly,
        otherwise the checkpoint cannot be restored.

        Raises:
            FileNotFoundError: if no checkpoint files exist at WEIGHTS_PATH.
        """
        class SELayer(tf.keras.Model):
            # Same attribute names and layer stack as the training-time SELayer.
            def __init__(self, filters, reduction=16):
                super().__init__()
                self.filters = filters
                self.reduction = reduction
                self.GAP = tf.keras.layers.GlobalAveragePooling2D()
                self.FC = tf.keras.models.Sequential([
                    tf.keras.layers.Dense(units=self.filters // self.reduction, input_shape=(self.filters,)),
                    tf.keras.layers.Dropout(0.5),
                    tf.keras.layers.BatchNormalization(),
                    tf.keras.layers.Activation('relu'),
                    tf.keras.layers.Dense(units=filters),
                    tf.keras.layers.Dropout(0.5),
                    tf.keras.layers.BatchNormalization(),
                    tf.keras.layers.Activation('sigmoid')
                ])
                self.Multiply = tf.keras.layers.Multiply()

            def call(self, inputs, training=None, mask=None):
                x = self.GAP(inputs)
                x = self.FC(x)
                return self.Multiply([x, inputs])

        num_classes = len(self.class_names)
        feature = resnet50(num_classes=num_classes, include_top=False)
        feature.trainable = False
        model = tf.keras.Sequential([feature,
                                     SELayer(2048),
                                     tf.keras.layers.GlobalAveragePooling2D(),
                                     tf.keras.layers.Dense(1024),
                                     tf.keras.layers.Dropout(0.5),
                                     tf.keras.layers.Dense(num_classes, activation='softmax')
                                     ])
        # TF-format checkpoints are split across several files sharing this prefix.
        if not glob.glob(self.WEIGHTS_PATH + "*"):
            raise FileNotFoundError("cannot find {}".format(self.WEIGHTS_PATH))
        model.load_weights(self.WEIGHTS_PATH)
        return model

    def predict_img(self):
        """Classify images/target.png and display the predicted species name."""
        try:
            with Image.open('images/target.png') as pil_img:
                rgb = pil_img.convert('RGB').resize((224, 224))
        except OSError:
            QMessageBox.warning(self, '错误', '无法读取待识别图片,请先上传图片')
            return
        # Same mean-centering as pre_function in the training script.
        img = np.asarray(rgb, dtype=np.float32) - self.CHANNEL_MEANS

        if self._model is None:
            try:
                self._model = self._build_model()
            except FileNotFoundError as err:
                QMessageBox.critical(self, '错误', str(err))
                return

        outputs = np.squeeze(self._model.predict(img[np.newaxis, ...]))
        result_index = int(np.argmax(outputs))
        result = self.class_names[result_index]
        print(result_index, result)
        self.result.setText(result)

    def closeEvent(self, event):
        """Ask for confirmation; accept the close event only if the user says yes."""
        reply = QMessageBox.question(self,
                                     '退出',
                                     "是否要退出程序?",
                                     QMessageBox.Yes | QMessageBox.No,
                                     QMessageBox.No)
        # Accepting the event is enough to close; calling self.close() here
        # would re-enter the close machinery.
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()


if __name__ == "__main__":
    # Start the Qt event loop with the classifier window and exit with its status.
    application = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(application.exec_())

效果图:

鱼类模型、预处理文件及前端显示相关文件已上传至资源中,可自行下载。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值