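基本流程
- 加载预测参数(测试/验证模式、数据列表文件、模型名称、待评估的 epoch 范围)
- 加载网络结构(ResNet-v2 或 EfficientNet),逐个恢复各 epoch 的权重
- 循环预测每张图像,输出预测结果
- 计算评价指标:准确度、灵敏度、特异度,以及多个模型的平均值
完整代码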
import cv2
import os
import shutil
import numpy as np
import tensorflow as tf
import config as cfg
from tensorflow.contrib.slim.nets import resnet_v2
from Model import *
import matplotlib
# Select the non-interactive backend BEFORE pyplot is imported; on older
# matplotlib releases a later matplotlib.use() call is ignored with a warning.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# Expose only GPU 0; this must be set before TensorFlow initialises the GPU
# (i.e. before the first tf.Session is created).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class ResNetTest(object):
    """Evaluate saved binary-classification checkpoints on a labelled image list.

    The network is a slim ResNet-v2-50 or an EfficientNet, selected by
    ``cfg.efficientnet.base_architecture[0]``. :meth:`predict` restores a
    range of epoch checkpoints. For each one it prints the confusion matrix
    and the accuracy, sensitivity and specificity, then reports the mean over
    the checkpoints whose numbers look credible.
    """

    def __init__(self, weight_file):
        """Build the inference graph and open a TensorFlow session.

        Args:
            weight_file: stored on the instance but not restored here; each
                checkpoint is restored inside :meth:`predict`.
        """
        self.input_size = cfg.train.input_size
        self.num_classes = 2
        self.moving_ave_decay = cfg.efficientnet.moving_ave_decay
        self.weight_file = weight_file
        self.batch_norm_decay = 0.997
        self.base_architecture = cfg.efficientnet.base_architecture[0]
        self.model_name = 'efficientnet-b0'
        self.mode = False  # train/test flag (not read anywhere in this class)
        self.inputs = tf.placeholder(tf.float32,
                                     shape=(None, self.input_size[0], self.input_size[1], 3),
                                     name='input')
        self.label_c = tf.placeholder(tf.int64, shape=(cfg.train.batch_size,), name='label')
        self.trainable = tf.placeholder(dtype=tf.bool, name='training')
        if self.base_architecture in ['resnet_v1_50', 'resnet_v2_50', 'resnet_v2_101']:
            # NOTE(review): every listed name builds resnet_v2_50. Only change this
            # together with the training code, or checkpoint restore will break.
            base_model = resnet_v2.resnet_v2_50
            with tf.contrib.slim.arg_scope(resnet_v2.resnet_arg_scope(batch_norm_decay=self.batch_norm_decay)):
                self.logits, self.end_points = base_model(self.inputs,
                                                          num_classes=self.num_classes,
                                                          is_training=self.trainable)
        else:
            print('using ', self.model_name)
            # Only the class count is overridden for the EfficientNet builder.
            override_params = {'num_classes': self.num_classes}
            model_builder = model_builder_factory.get_model_builder(self.model_name)
            self.logits, _ = model_builder.build_model(self.inputs,
                                                       self.model_name,
                                                       self.trainable,
                                                       override_params=override_params)
        self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        # Restore all global variables as stored (the EMA-restore variant is disabled).
        self.saver = tf.train.Saver(tf.global_variables())
        # Sigmoid scores; argmax over axis 1 picks the predicted class index.
        logit_soft = tf.nn.sigmoid(self.logits)
        print(logit_soft)
        self.predict_ = tf.argmax(logit_soft, axis=1)

    def predict(self, flag='test_data', data_name='0920_roc_color.txt',
                model_name='model_1111_test4', start_epoch=60, num_epochs=19):
        """Evaluate the checkpoints from ``model-<start_epoch>`` through ``model-<start_epoch + num_epochs - 1>``.

        Args:
            flag: ``'test_data'`` reads lines ``<name> <value> ...`` whose
                images live under ``cfg.train.root_path``. Any other value
                reads validation lines ``<name, may contain spaces> <value>``
                whose images live under ``cfg.test.image_path``. A value
                >= 10 is labelled 1, otherwise 0.
            data_name: list file, relative to ``cfg.train.root_path``.
            model_name: sub-directory of ``./checkpoint/model/`` holding the checkpoints.
            start_epoch: first epoch number to restore.
            num_epochs: number of consecutive epochs to evaluate.

        Returns:
            ``(mean_acc, mean_sensitive, mean_specify)``, averaged over the
            checkpoints whose sensitivity and specificity both lie strictly
            inside (0.2, 0.98). Each value is NaN when no checkpoint qualifies.
        """
        # The samples do not depend on the checkpoint, so parse/decode them once.
        images, labels = self._load_samples(flag, data_name)
        acc_list = []
        sensitive_list = []
        specify_list = []
        for epoch in range(start_epoch, start_epoch + num_epochs):
            print('--------------------------epoch-', epoch)
            weight_file = "./checkpoint/model/" + model_name + "/model-" + str(epoch)
            print('restore weight file:', weight_file)
            self.saver.restore(self.sess, weight_file)
            predict_list = [self._predict_one(image) for image in images]
            # Counted in NumPy: creating a tf.math.confusion_matrix op on every
            # epoch would keep growing the default graph.
            confusion_matrix_ = self._confusion_matrix(labels, predict_list, self.num_classes)
            acc, sensitive, specify = self._binary_metrics(confusion_matrix_)
            print('混淆矩阵\n', confusion_matrix_)
            print('准确度, 灵敏度, 特异度: ', acc, sensitive, specify)
            # Near-0/near-1 rates, or NaN from a missing class, are discarded.
            if not (self._is_credible(sensitive) and self._is_credible(specify)):
                print('数值不可信')
                continue
            acc_list.append(acc)
            sensitive_list.append(sensitive)
            specify_list.append(specify)
        # Mean over the credible checkpoints only.
        mean_acc = self._mean_or_nan(acc_list)
        mean_sensitive = self._mean_or_nan(sensitive_list)
        mean_specify = self._mean_or_nan(specify_list)
        print('====================================')
        print('平均准确度, 平均灵敏度, 平均特异度: ', mean_acc, mean_sensitive, mean_specify)
        return mean_acc, mean_sensitive, mean_specify

    def _load_samples(self, flag, data_name):
        """Read the list file and return ``(images, labels)`` as parallel lists.

        Images are resized to ``input_size`` but not yet normalised. Blank
        lines are ignored, and images OpenCV cannot read are reported and
        skipped.
        """
        images = []
        labels = []
        with open(cfg.train.root_path + data_name, 'r') as name_list:
            for line in name_list:
                line = line.strip()
                if not line:
                    continue
                sample = self._parse_line(line.split(' '), flag)
                if sample is None:
                    continue
                path, label = sample
                img = cv2.imread(path)
                if img is None:
                    print('cannot read image:', path)
                    continue
                # cv2.resize takes (width, height); input_size is (height, width).
                images.append(cv2.resize(img, (self.input_size[1], self.input_size[0])))
                labels.append(label)
        return images, labels

    @staticmethod
    def _parse_line(fields, flag):
        """Return ``(image_path, label)`` for one split list line, or None if malformed."""
        if flag == 'test_data':
            if len(fields) == 1:
                print('name error')
                print('s1: ', fields)
                return None
            return cfg.train.root_path + fields[0], int(float(fields[1]) >= 10)
        # Validation names may contain spaces: everything but the last field.
        return cfg.test.image_path + ' '.join(fields[:-1]), int(float(fields[-1]) >= 10)

    def _predict_one(self, image):
        """Return the predicted class index for one resized image."""
        batch = image.astype(np.float32) / 255.0  # scale pixel values to [0, 1]
        batch = np.reshape(batch, (1, self.input_size[0], self.input_size[1], 3))
        prediction = self.sess.run(self.predict_,
                                   feed_dict={self.inputs: batch, self.trainable: False})
        return prediction[0]

    @staticmethod
    def _confusion_matrix(labels, predictions, num_classes=2):
        """Count matrix with rows = true labels and columns = predictions."""
        confusion = np.zeros((num_classes, num_classes), dtype=np.int64)
        np.add.at(confusion,
                  (np.asarray(labels, dtype=np.int64), np.asarray(predictions, dtype=np.int64)),
                  1)
        return confusion

    @staticmethod
    def _binary_metrics(confusion):
        """Return ``(accuracy, sensitivity, specificity)``; NaN where a denominator is 0."""
        tn, fp = confusion[0]
        fn, tp = confusion[1]
        total = tn + fp + fn + tp
        acc = (tp + tn) / total if total else float('nan')
        sensitive = tp / (tp + fn) if (tp + fn) else float('nan')
        specify = tn / (tn + fp) if (tn + fp) else float('nan')
        return acc, sensitive, specify

    @staticmethod
    def _is_credible(value):
        """True when the rate lies strictly inside (0.2, 0.98); False for NaN."""
        return 0.2 < value < 0.98

    @staticmethod
    def _mean_or_nan(values):
        """Mean of ``values``, or NaN (without a RuntimeWarning) when empty."""
        return np.mean(values) if values else float('nan')
if __name__ == '__main__':
    res = ResNetTest(weight_file=None)
    try:
        res.predict()
    finally:
        # Release the TensorFlow session even if a checkpoint fails to restore.
        res.sess.close()