EfficientNet Transfer Learning (Part 5): Network Prediction (predict.py)

Basic Workflow

  1. Load the prediction parameters
  2. Build the network and restore the weights
  3. Loop over the test images and output the predictions
  4. Compute the evaluation metrics (see the sketch after this list)
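
Step 4 reduces to deriving accuracy, sensitivity, and specificity from a 2x2 confusion matrix. As a minimal sketch in plain NumPy (independent of the TensorFlow graph used in the full script below):

import numpy as np

def binary_metrics(labels, predictions):
    """Accuracy, sensitivity and specificity for a binary classifier."""
    labels = np.asarray(labels)
    predictions = np.asarray(predictions)
    tp = np.sum((labels == 1) & (predictions == 1))
    tn = np.sum((labels == 0) & (predictions == 0))
    fp = np.sum((labels == 0) & (predictions == 1))
    fn = np.sum((labels == 1) & (predictions == 0))
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    sensitivity = tp / (tp + fn)   # true positive rate
    specificity = tn / (tn + fp)   # true negative rate
    return accuracy, sensitivity, specificity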

Complete Code

import os
import cv2
import numpy as np
import tensorflow as tf
import config as cfg
from tensorflow.contrib.slim.nets import resnet_v2
from Model import *
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before importing pyplot
import matplotlib.pyplot as plt
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


class ResNetTest(object):
    def __init__(self, weight_file):
        self.input_size = cfg.train.input_size
        self.num_classes = 2
        self.moving_ave_decay = cfg.efficientnet.moving_ave_decay
        self.weight_file = weight_file
        self.batch_norm_decay = 0.997
        self.base_architecture = cfg.efficientnet.base_architecture[0]
        self.model_name = 'efficientnet-b0'
        self.mode = False  # flag distinguishing training from testing

        self.inputs = tf.placeholder(tf.float32,
                                     shape=(None, self.input_size[0], self.input_size[1], 3),
                                     name='input')

        self.label_c = tf.placeholder(tf.int64, shape=(cfg.train.batch_size,), name='label')
        self.trainable = tf.placeholder(dtype=tf.bool, name='training')

        if self.base_architecture in ['resnet_v1_50', 'resnet_v2_50', 'resnet_v2_101']:
            base_model = resnet_v2.resnet_v2_50
            with tf.contrib.slim.arg_scope(resnet_v2.resnet_arg_scope(batch_norm_decay=self.batch_norm_decay)):
                self.logits, self.end_points = base_model(self.inputs,
                                                          num_classes=self.num_classes,
                                                          is_training=self.trainable)
        else:
            print('using ', self.model_name)
            override_params = {}
            # Other EfficientNet overrides (batch_norm_momentum, batch_norm_epsilon,
            # dropout_rate, survival_prob, data_format, depth/width coefficients)
            # could also be set here; only the number of classes is overridden.
            override_params['num_classes'] = self.num_classes

            # model_builder_factory is assumed to be provided by `from Model import *`
            # (the model builder factory from the EfficientNet reference code).
            model_builder = model_builder_factory.get_model_builder(self.model_name)

            self.logits, _ = model_builder.build_model(self.inputs,
                                                       self.model_name,
                                                       self.trainable,
                                                       override_params=override_params)

        # with tf.name_scope('ema'):
        #     ema_obj = tf.train.ExponentialMovingAverage(self.moving_ave_decay)

        self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        # self.saver = tf.train.Saver(ema_obj.variables_to_restore())
        self.saver = tf.train.Saver(tf.global_variables())
        # self.saver.restore(self.sess, self.weight_file)

        # apply sigmoid to the logits, then take the argmax as the predicted class
        logit_soft = tf.nn.sigmoid(self.logits)
        print(logit_soft)
        # logit_squeeze = tf.squeeze(logit_soft, axis=[1, 2])
        self.predict_ = tf.argmax(logit_soft, axis=1)

    def predict(self):
        # if os.path.exists(model_name + '.txt'):
        #     os.remove(model_name + '.txt')

        # evaluate the checkpoint saved at every epoch
        acc_list = []
        sensitive_list = []
        specify_list = []

        # =========================================================================================
        # test configuration
        flag = 'test_data'  # valid_data or test_data

        # 0908_roc_color.txt 0915_roc_color.txt 0920_roc_color.txt
        # valid_1025_color.txt
        data_name = '0920_roc_color.txt'
        model_name = 'model_1111_test4'
        # =========================================================================================

        for i in range(19):
            j = 60  # evaluate the checkpoints saved from epoch 60 to epoch 78
            print('-------------------------- epoch', i + j)
            # tf.reset_default_graph()
            # self.sess.run(tf.global_variables_initializer())
            weight_file = "./checkpoint/model/" + model_name + "/model-" + str(i + j)
            print('restore weight file:', weight_file)
            self.saver.restore(self.sess, weight_file)

            name_list = open(cfg.train.root_path + data_name, 'r')

            label_list = []
            predict_list = []
            image_list = []
            cobb = []
            test_name = []
            for line in name_list:
                line = line.strip()
                s1 = line.split(' ')

                # test-set branch
                if flag == 'test_data':
                    if len(s1) != 1:
                        file = cfg.train.root_path+s1[0]
                        # print(file)
                        img = cv2.imread(file)
                        # print('image size:', img.shape)
                        test_name.append(s1[0])
                        # print(image.shape)
                        # cv2.imshow('image', image)
                        # cv2.waitKey(0)
                        if float(s1[1]) >= 10:  # Cobb angle >= 10 degrees is the positive class
                            value = 1
                            label_list.append(value)
                            cobb.append(s1[-1])
                        else:
                            value = 0
                            label_list.append(value)
                            cobb.append(s1[-1])

                        img = cv2.resize(img, (self.input_size[1], self.input_size[0]))
                        img = img.astype(np.float32)

                        # std_bgr = [6.550119402970968, 6.312448303275082, 8.977662213055952]
                        # mean_bgr = [33.65559903, 120.61937841, 116.81338165]
                        # img = (img - np.array([33.65559903, 120.61937841, 116.81338165])) / np.array(std_bgr)

                        img = img/255.0

                        img = np.reshape(img, (1, self.input_size[0], self.input_size[1], 3))

                        predict = self.sess.run([self.predict_],
                                                feed_dict={self.inputs: img, self.trainable: False})
                        predict_list.append(predict[0][0])

                    else:
                        print('name error')
                        print('s1: ', s1)

                # ===================================================================
                #   The branch below evaluates on the validation set and behaves as expected
                # ===================================================================
                else:
                    if len(s1) == 2:
                        data_image = cfg.test.image_path + s1[0]
                        if float(s1[-1]) >= 10:
                            data_label = 1

                        # elif float(d1[-1]) >= 10:
                        #     continue
                        else:
                            data_label = 0

                    else:
                        ss = ' '.join(s1[:-1])
                        data_image = cfg.test.image_path + ss

                        if float(s1[-1]) >= 10:
                            data_label = 1
                        # elif float(d1[-1]) >= 10:
                        #     continue
                        else:
                            data_label = 0
                    # print(data_image)
                    img = cv2.imread(data_image)
                    # human_data.show_image('ori image', img)

                    img = cv2.resize(img, (self.input_size[1], self.input_size[0]))
                    # human_data.show_image('resize', img)

                    img = img.astype(np.float32)

                    # normalization option 1: fixed per-channel BGR mean/std
                    # std_bgr = [6.550119402970968, 6.312448303275082, 8.977662213055952]
                    # mean_bgr = [33.65559903, 120.61937841, 116.81338165]
                    # img = (img - np.array([33.65559903, 120.61937841, 116.81338165])) / np.array(std_bgr)

                    # normalization option 2: per-image standardization
                    # img = (img - np.mean(img, axis=(0, 1))) / (np.std(img, axis=(0, 1)) + 1e-8)

                    # normalization option 3: simple scaling to [0, 1]
                    img = img/255.0

                    img = np.reshape(img, (1, self.input_size[0], self.input_size[1], 3))

                    label_list.append(data_label)
                    image_list.append(img)
                    predict = self.sess.run([self.predict_], feed_dict={self.inputs: img, self.trainable: False})
                    predict_list.append(predict[0][0])

            name_list.close()
            # for j in range(len(label_list)):
            #     if label_list[j] != predict_list[j]:
            #         print('ground truth / prediction / Cobb angle: ', label_list[j], predict_list[j], cobb[j], test_name[j])

            # print('ground truth:', label_list)
            # print('predictions:', predict_list)
            # note: this builds a new confusion_matrix op on every iteration, which
            # slowly grows the graph; computing the matrix in NumPy would avoid that
            confusion_matrix = tf.math.confusion_matrix(np.hstack(label_list),
                                                        np.hstack(predict_list),
                                                        num_classes=2)
            confusion_matrix_ = self.sess.run(confusion_matrix)

            TN = confusion_matrix_[0][0]
            FP = confusion_matrix_[0][1]
            FN = confusion_matrix_[1][0]
            TP = confusion_matrix_[1][1]

            acc = (TP + TN) / (TP + TN + FP + FN)
            sensitive = TP / (TP + FN)
            specify = TN / (TN + FP)

            print('confusion matrix\n', confusion_matrix_)
            print('accuracy, sensitivity, specificity: ', acc, sensitive, specify)
            if sensitive <= 0.2 or sensitive >= 0.98:
                print('metric looks unreliable, skipping this checkpoint')
                continue

            if specify <= 0.2 or specify >= 0.98:
                print('metric looks unreliable, skipping this checkpoint')
                continue

            acc_list.append(acc)
            sensitive_list.append(sensitive)
            specify_list.append(specify)

            # txt = open(model_name + '.txt', 'a+')
            # txt.write('accuracy, sensitivity, specificity: '
            #           + str(round(acc, 3)) + ' '
            #           + str(round(sensitive, 3)) + ' '
            #           + str(round(specify, 3)) + '\n')

        # average the metrics over all evaluated checkpoints
        mean_acc = np.mean(acc_list)
        mean_sensitive = np.mean(sensitive_list)
        mean_specify = np.mean(specify_list)
        print('====================================')
        print('mean accuracy, mean sensitivity, mean specificity: ', mean_acc, mean_sensitive, mean_specify)


if __name__ == '__main__':
    res = ResNetTest(weight_file=None)
    res.predict()
    res.sess.close()
    # res.vis_feature_map()
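
Since the script already imports matplotlib with the Agg backend but never plots anything, a natural extension is to visualize how the three metrics evolve across the evaluated checkpoints instead of only printing their means. A minimal sketch (the helper name plot_metrics and the output file metrics.png are illustrative, not part of the original script):

import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are written to disk
import matplotlib.pyplot as plt

def plot_metrics(acc_list, sensitive_list, specify_list, out_file='metrics.png'):
    """Plot per-checkpoint accuracy/sensitivity/specificity curves and save them."""
    plt.figure(figsize=(8, 5))
    plt.plot(acc_list, label='accuracy')
    plt.plot(sensitive_list, label='sensitivity')
    plt.plot(specify_list, label='specificity')
    plt.xlabel('evaluated checkpoint index')
    plt.ylabel('metric value')
    plt.legend()
    plt.savefig(out_file)
    plt.close()

Called with the acc_list, sensitive_list, and specify_list collected in predict(), this draws one curve per metric, which makes it easier to pick the best epoch than the printed averages alone.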
