预测流程
- 加载预测所需的参数
- 构建网络结构并恢复训练好的权重
- 逐张图像循环预测并输出结果
- 计算评价指标（准确度、灵敏度、特异度）
代码展示
以下代码是 predict.py：
import cv2
import os
import shutil
import numpy as np
import tensorflow as tf
import config as cfg
from tensorflow.contrib.slim.nets import resnet_v2
from Model import *
class ResNetTest(object):
    """Evaluate a trained ResNet-v2-50 binary classifier on a list of images.

    ``__init__`` builds the inference graph and restores the checkpoint named
    by ``cfg.Test.WEIGHT_FILE``. ``predict`` runs the network image by image
    over a list file and prints the confusion matrix, accuracy, sensitivity
    and specificity.
    """

    def __init__(self):
        # config.py keeps all settings on a module-level EasyDict called
        # ``cfg`` (its header says "from config import cfg"). The file-level
        # ``import config as cfg`` binds the *module*, which has no ``Train``
        # attribute, so import the settings object explicitly here.
        from config import cfg

        # 0. Load the parameters.
        self.input_size = cfg.Train.Input_Size  # [height, width, channels]
        self.root_path = cfg.Train.Root_Path
        self.num_classes = cfg.ResNet.Num_Classes
        self.moving_ave_decay = cfg.ResNet.Moving_Ave_Decay
        self.weight_file = cfg.Test.WEIGHT_FILE
        self.batch_norm_decay = cfg.ResNet.Batch_Norm_Decay
        self.base_architecture = cfg.ResNet.Base_Architecture[1]  # 'resnet_v2_50'

        height, width = self.input_size[0], self.input_size[1]
        self.inputs = tf.placeholder(tf.float32,
                                     shape=(None, height, width, 3),
                                     name='input')
        # Not fed by predict(); kept so the public attribute still exists.
        self.label_c = tf.placeholder(tf.int64, shape=(cfg.Train.Batch_Size,), name='label')
        # Fed False at inference time so batch norm runs in inference mode.
        self.trainable = tf.placeholder(dtype=tf.bool, name='training')

        # 1. Build the network.
        with tf.contrib.slim.arg_scope(
                resnet_v2.resnet_arg_scope(batch_norm_decay=self.batch_norm_decay)):
            self.logits, self.end_points = resnet_v2.resnet_v2_50(
                self.inputs,
                num_classes=self.num_classes,
                is_training=self.trainable)

        # 2. Predicted class index per image; this is the op predict() runs.
        # The reshape flattens any singleton spatial dims so argmax is taken
        # over the class axis regardless of the logits' exact rank.
        self.predict_ = tf.argmax(
            tf.reshape(self.logits, [-1, self.num_classes]), axis=-1)

        # 3. Restore the trained weights. To restore the EMA shadow weights
        # instead, build the Saver from
        # tf.train.ExponentialMovingAverage(self.moving_ave_decay).variables_to_restore().
        self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        self.saver = tf.train.Saver(tf.global_variables())
        self.saver.restore(self.sess, self.weight_file)

    def predict(self, list_file=None, threshold=10):
        """Run the network over every image in a list file and print metrics.

        Each line of *list_file* must be ``<image path> <value>`` separated by
        a single space. The path is resolved by prefixing ``self.root_path``.
        The ground-truth label is 1 when ``float(value) >= threshold`` and 0
        otherwise. ``value`` is presumably the Cobb angle; the original code
        collected it into a list named ``cobb``.

        Args:
            list_file: path of the list file. Defaults to
                ``<Root_Path>0920_roc.txt``, as in the original script.
            threshold: value at or above which a sample is labelled positive.

        Returns:
            ``(confusion_matrix, accuracy, sensitivity, specificity)``, or
            ``None`` when no sample could be evaluated.
        """
        print('Run Txt Mode ...')
        if list_file is None:
            list_file = self.root_path + '0920_roc.txt'
        height, width = self.input_size[0], self.input_size[1]

        label_list = []
        predict_list = []
        # 4. Iterate over the test list.
        with open(list_file, 'r') as name_list:
            for line in name_list:
                line = line.strip()
                if not line:
                    continue
                s1 = line.split(' ')
                if len(s1) != 2:
                    print('name error')
                    print('s1: ', s1)
                    continue

                file_path = self.root_path + s1[0]
                image = cv2.imread(file_path)
                if image is None:
                    # cv2.imread reports a missing/unreadable file by returning
                    # None rather than raising; skip it so labels and
                    # predictions stay aligned.
                    print('read error: ', file_path)
                    continue

                value = 1 if float(s1[1]) >= threshold else 0

                # cv2.resize takes (width, height); scale pixels to [0, 1].
                img = cv2.resize(image, (width, height)).astype(np.float32) / 255.0
                img = np.reshape(img, (1, height, width, 3))
                predicted = int(self.sess.run(
                    self.predict_,
                    feed_dict={self.inputs: img, self.trainable: False})[0])

                if predicted != value:
                    print("预测值: ", predicted)
                    print("value: ", value)
                    print('======================')
                label_list.append(value)
                predict_list.append(predicted)

        if not label_list:
            print('no valid samples in: ', list_file)
            return None

        print('真实值:', label_list)
        print('预测值:', predict_list)
        # 5. Confusion matrix, sensitivity and specificity.
        confusion_matrix_, acc, sensitive, specify = self._confusion_metrics(
            label_list, predict_list)
        print('混淆矩阵\n', confusion_matrix_)
        print('准确度, 灵敏度, 特异度: ', acc, sensitive, specify)
        return confusion_matrix_, acc, sensitive, specify

    @staticmethod
    def _confusion_metrics(labels, predictions):
        """Return ``(matrix, accuracy, sensitivity, specificity)`` for 0/1 labels.

        Rows index the true class and columns the predicted class, the same
        layout as ``tf.math.confusion_matrix``. So ``matrix[0][0]`` is TN,
        ``[0][1]`` is FP, ``[1][0]`` is FN and ``[1][1]`` is TP. A ratio whose
        denominator is zero is returned as ``nan``.
        """
        labels = np.asarray(labels, dtype=np.int64).ravel()
        predictions = np.asarray(predictions, dtype=np.int64).ravel()
        matrix = np.zeros((2, 2), dtype=np.int64)
        np.add.at(matrix, (labels, predictions), 1)
        tn, fp = matrix[0]
        fn, tp = matrix[1]

        def ratio(num, den):
            return float(num) / den if den else float('nan')

        acc = ratio(tp + tn, matrix.sum())
        sensitive = ratio(tp, tp + fn)
        specify = ratio(tn, tn + fp)
        return matrix, acc, sensitive, specify
if __name__ == '__main__':
    # Build the graph and restore the checkpoint, then evaluate the test list.
    tester = ResNetTest()
    tester.predict()
以下代码是 config.py，用于设置训练、验证和预测过程中的相应参数：
from easydict import EasyDict as edict
cfg = edict()
# Consumers get the settings with: from config import cfg

# ResNet options
cfg.ResNet = edict()
cfg.ResNet.Num_Classes = 2
cfg.ResNet.Moving_Ave_Decay = 0.9995
cfg.ResNet.Pre_Trained_Model = ['./resnet_v1_50/resnet_v1_50.ckpt',
                                './resnet_v2_50/resnet_v2_50.ckpt',
                                './vgg16/vgg_16_2016_08_28/vgg_16.ckpt',
                                './resnet_v2_101/resnet_v2_101.ckpt']
# Same order as Pre_Trained_Model above.
cfg.ResNet.Base_Architecture = ['resnet_v1_50', 'resnet_v2_50', 'vgg16', 'resnet_v2_101']
cfg.ResNet.Batch_Norm_Decay = 0.997

# Train options
cfg.Train = edict()
cfg.Train.Root_Path = '../0929_result/'
cfg.Train.Train_Set = "../0929_result/train_0929.txt"
cfg.Train.Valid_Set = "../0929_result/valid_0914.txt"
# Per-channel mean of the training set: [108.73639117 92.45784024 0.]
cfg.Train.Train_Num = 2376
cfg.Train.Valid_Num = 256
cfg.Train.Batch_Size = 32
# Derived from the two settings above so it cannot drift when either changes
# (evaluates to 74, the same as the former literal 2376//32).
cfg.Train.Step_Per_Epoch = cfg.Train.Train_Num // cfg.Train.Batch_Size
cfg.Train.Input_Size = [250, 250, 3]  # [height, width, channels]
cfg.Train.Learn_Rate_Init = 0.0001
cfg.Train.Learn_Rate_End = 1e-6
cfg.Train.Warmup_Epochs = 15
cfg.Train.First_Stage_Epochs = 50
cfg.Train.Second_Stage_Epochs = 120
cfg.Train.Loss_Function = ['sigmoid', 'softmax']
cfg.Train.Log = './checkpoint/log/log_0929_0/'
cfg.Train.Save_Model = './checkpoint/model/model_0929_0/model'

# TEST options
cfg.Test = edict()
cfg.Test.Mode = ['txt', 'image']
cfg.Test.Image_Path = '../0914_rgb_new/'
cfg.Test.WEIGHT_FILE = "./checkpoint/model/model_0929_0/model-150"