ResNet迁移学习(五)—— 网络预测(predict.py)

预测流程

  1. 加载预测参数
  2. 加载网络结构,恢复权重
  3. 循环预测,输出结果
  4. 评价指标

代码展示

以下代码是predict.py

import cv2
import os
import shutil
import numpy as np
import tensorflow as tf
import config as cfg
from tensorflow.contrib.slim.nets import resnet_v2
from Model import *


class ResNetTest(object):
    def __init__(self):
    	# 0. 加载相应的参数
        self.input_size = cfg.Train.Input_Size
        self.num_classes = 2
        self.moving_ave_decay = cfg.ResNet.Moving_Ave_Decay
        self.weight_file = cfg.Test.WEIGHT_FILE
        self.batch_norm_decay = 0.997
        self.base_architecture = cfg.ResNet.Base_Architecture[1]
        # self.pre_trained_model = cfg.ResNet.Pre_Trained_Model[0]

        self.inputs = tf.placeholder(tf.float32,
                                     shape=(None, cfg.Train.Input_Size[0], cfg.Train.Input_Size[1], 3),
                                     name='input')

        self.label_c = tf.placeholder(tf.int64, shape=(cfg.Train.Batch_Size,), name='label')
        self.trainable = tf.placeholder(dtype=tf.bool, name='training')
		
		# 1. 加载定义的网络结构
        base_model = resnet_v2.resnet_v2_50
        with tf.contrib.slim.arg_scope(resnet_v2.resnet_arg_scope(batch_norm_decay=self.batch_norm_decay)):
            self.logits, self.end_points = base_model(self.inputs,
                                                      num_classes=self.num_classes,
                                                      is_training=self.trainable)

        with tf.name_scope('ema'):
            ema_obj = tf.train.ExponentialMovingAverage(self.moving_ave_decay)
		
		# 3. 恢复训练好的网络权重参数
        self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        # self.saver = tf.train.Saver(ema_obj.variables_to_restore())
        self.saver = tf.train.Saver(tf.global_variables())
        self.saver.restore(self.sess, self.weight_file)

    def predict(self):
        mode = cfg.Test.Mode[0]
			
        print('Run Txt Mode ...')
        # 4. 测试集数据集
        name_list = open(cfg.Train.Root_Path + '0920_roc.txt', 'r')

        label_list = []
        predict_list = []
        image_list = []
        cobb = []
        test_name = []
        for line in name_list:

            line = line.strip()
            s1 = line.split(' ')

            # # 测试集测试代码
            if len(s1) == 2:
                file = cfg.Train.Root_Path+s1[0]
                # print(file)
                image = cv2.imread(file)
                # print('图像大小:', image.shape)
                test_name.append(s1[0])
                # print(image.shape)
                # cv2.imshow('image', image)
                # cv2.waitKey(0)
                if float(s1[1]) >= 10:
                    value = 1
                    label_list.append(value)
                    cobb.append(s1[-1])
                else:
                    value = 0
                    label_list.append(value)
                    cobb.append(s1[-1])

                img = cv2.resize(image, (cfg.Train.Input_Size[1], cfg.Train.Input_Size[0]))

                img = img.astype(np.float32)
                img = img/255.0
                img = np.reshape(img, (1, cfg.Train.Input_Size[0], cfg.Train.Input_Size[1], 3))

                predict = self.sess.run([predict_], feed_dict={self.inputs: img, self.trainable: False})
                if predict[0][0] != value:
                    print("预测值: ", predict[0])
                    # print('ss: ', predict[0].shape)
                    print("value: ", value)
                    print('======================')

                predict_list.append(predict[0][0])

            else:
                print('name error')
                print('s1: ', s1)

        print('真实值:', label_list)
        print('预测值:', predict_list)
        # 5. 混淆矩阵,获得灵敏度和特异度指标值
        confusion_matrix = tf.math.confusion_matrix(np.hstack(label_list), np.hstack(predict_list), num_classes=2)
        confusion_matrix_ = self.sess.run(confusion_matrix)

        TN = confusion_matrix_[0][0]
        FP = confusion_matrix_[0][1]
        FN = confusion_matrix_[1][0]
        TP = confusion_matrix_[1][1]

        acc = (TP + TN) / (TP + TN + FP + FN)
        sensitive = TP / (TP + FN)
        specify = TN / (TN + FP)

        print('混淆矩阵\n', confusion_matrix_)
        print('准确度, 灵敏度, 特异度: ', acc, sensitive, specify)


if __name__ == '__main__':
    res = ResNetTest()
    res.predict()

以下代码是config.py,设置训练,验证,和预测过程中相应的参数:


from easydict import EasyDict as edict

cfg = edict()

# Consumers can get config by: from config import cfg
# cfg = cfg

# ResNet options
cfg.ResNet = edict()

cfg.ResNet.Num_Classes = 2
cfg.ResNet.Moving_Ave_Decay = 0.9995
cfg.ResNet.Pre_Trained_Model = ['./resnet_v1_50/resnet_v1_50.ckpt',
                                './resnet_v2_50/resnet_v2_50.ckpt',
                                './vgg16/vgg_16_2016_08_28/vgg_16.ckpt',
                                './resnet_v2_101/resnet_v2_101.ckpt']

cfg.ResNet.Base_Architecture = ['resnet_v1_50', 'resnet_v2_50', 'vgg16', 'resnet_v2_101']
cfg.ResNet.Batch_Norm_Decay = 0.997

# Train options
cfg.Train = edict()

cfg.Train.Root_Path = '../0929_result/'
cfg.Train.Train_Set = "../0929_result/train_0929.txt"
cfg.Train.Valid_Set = "../0929_result/valid_0914.txt"
# 训练集通道均值 [108.73639117  92.45784024   0.        ]
cfg.Train.Train_Num = 2376
cfg.Train.Valid_Num = 256
cfg.Train.Batch_Size = 32
cfg.Train.Step_Per_Epoch = 2376//32

cfg.Train.Input_Size = [250, 250, 3]
cfg.Train.Learn_Rate_Init = 0.0001
cfg.Train.Learn_Rate_End = 1e-6
cfg.Train.Warmup_Epochs = 15
cfg.Train.First_Stage_Epochs = 50
cfg.Train.Second_Stage_Epochs = 120

cfg.Train.Loss_Function = ['sigmoid', 'softmax']

cfg.Train.Log = './checkpoint/log/log_0929_0/'
cfg.Train.Save_Model = './checkpoint/model/model_0929_0/model'

# TEST options
cfg.Test = edict()
cfg.Test.Mode = ['txt', 'image']
cfg.Test.Image_Path = '../0914_rgb_new/'  
cfg.Test.WEIGHT_FILE = "./checkpoint/model/model_0929_0/model-150"
  • 0
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值