2.1主模块 :
only show accuracy:如果输入 y ,仅验证分类结果的最准确率,默认标签地址为output/mac(DAGM)/2022_05_01 01_51_04,均可自行调整
[0]model path: 默认地址model path: logs/mac_5/ep098-loss0.000-val_loss0.002.pth,注意加载网络权重后,记得调整siamese.py中的input_shape
[1]待检测图片文件夹(query set)地址:
[2]待检测图片,support set 地址:默认地址为query/mac_mini
[3]标签图片,support/mac(DAGM):默认地址为support/mac(DAGM)
# -*- coding:utf-8 -*-
import datetime
from utils.function import *
from PIL import Image
from siamese import Siamese
# 实例化网络
# Model_path = input(f'model weights:')
# if Model_path == '':
# Model_path = rf'logs/animal300_200/ep039-loss0.038-val_loss0.251.pth'
# print(f'1.模型权重:{Model_path}')
model = Siamese()
# model.generate(Model_path)
# 对比(相似度计算)函数
#相似度函数
def TwoImg(img1, img2):
try:
img = Image.open(img1)
except:
print('Image Open Error! Try again!')
try:
img2 = Image.open(img2)
except:
print('Image_2 Open Error! Try again!')
# 此时img已经转化为PIL格式
probability = model.detect_image(img, img2) # 主要函数(孪生网络模型函数)
# print(probability)
return probability
# 识别函数
def recognition(Img, List2):
# 初始化
prob = []
label = []
result = {}
# 与每张标签图片进行对比,返回相似度
for Path in List2:
path = Path.split('\\')[-1]
first, last = os.path.splitext(path)
img2 = Path # 标签图片
# 对比,输入图片地址,返回相似度(tensor格式)
prob_two = TwoImg(Img, img2)
# 将相似度tensor格式(prob_two)转化为float格式(similarity)
similarity = prob_two.tolist()[0]
# 将相似度结果加入相似度列表(prob)
prob.append(similarity)
# 将图片名称加入标签列表(label)
label.append(first)
# 打印相似度和图片名称列表
# print(f'\n1.标签为:{label}\n'
# f'2.与待检测图片相似度分别为:{prob}')
# 对相似度列表进行softmax转换
prob_sm = softmax(prob)
# print(f'3.各个类别的可能性:{prob_sm}\n')
# 打印各个类别的可能性百分比
for i in range(len(label)):
print(f'{label[i]}: {prob_sm[i] * 100:.2f}%(相似度:{prob[i] * 100:.5f}%)')
# 编制result字典:1.key为标签 label[i] 2.value为经softmax的相似度(即概率)prob_sm[i]
result[f'{label[i]}'] = prob_sm[i]
# 最有可能的预测结果
MAX = sorted(result,
key=result.get,
reverse=True)[0]
print(f'\n{Img}最有可能的结果是【{MAX}】,'
f'有{result[MAX] * 100:.2f}%的可能性.\n')
return prob_sm, result, MAX
#主函数
def main_sn():
if not input(f'1.only show accuracy:') == 'y':
# TODO 1.待检测图片地址
img_path1 = input(f'待检测图片文件夹(query set)地址:') # 待检测图片文件夹地址
if img_path1 == '':
img_path1 = rf'query/mac_mini'
print(f'2.待检测图片:{img_path1}')
list1 = GetFileList(img_path1, [])
# TODO 标签图片(label images)地址
img_path2 = input(f'support set 地址:')
if img_path2 == '':
img_path2 = rf'support/mac(DAGM)'
print(f'3.标签图片:{img_path2}')
# 获取标签图片名称等信息
list2 = GetFileList(img_path2, [])
num = 1
# 保存文件夹名称
define = '%Y_%m_%d %H_%M_%S'
system_time = datetime.datetime.now().strftime(define)
supportSet = img_path2.split('/')[-1]
save_path = rf'output/{supportSet}/'
for img in list1:
# img为待检测图片地址
print(f'-----------------------------------------image[{num}]----------------------------------------')
prob_sm, result, Max = recognition(img, list2)
if not os.path.exists(f'{save_path}/{system_time}/labels'):
os.makedirs(f'{save_path}/{system_time}/labels')
with open(rf'{save_path}/{system_time}/info.txt', 'a+') as f:
f.write(f'{Max} {result}\n')
first_img, last_img = img.split('\\')
last_img_excludePNG = last_img.split('.')[:-1][0]
with open(rf'{save_path}/{system_time}/labels/{last_img_excludePNG}.txt', 'a+') as f:
f.write(f'{Max}')
num += 1
predict_accuracy(rf'{save_path}/{system_time}')
return 'OK'
predict_accuracy(rf'output/mac(DAGM)/2022_05_01 01_51_04')
if __name__ == "__main__":
main_sn()
2.2函数模块
import os
import numpy as np
import os.path
# 输入文件夹地址和[],返回包含子文件名称的列表
def GetFileList(dir, fileList):
newDir = dir
if os.path.isfile(dir):
fileList.append(dir)
elif os.path.isdir(dir):
for s in os.listdir(dir):
# 如果需要忽略某些文件夹,使用以下代码
# if s == "xxx":
# continue
newDir = os.path.join(dir, s)
GetFileList(newDir, fileList)
return fileList
# softmax转换 输入输出为列表
def softmax(x):
"""Compute the softmax in a numerically stable way."""
x = x - np.max(x)
exp_x = np.exp(x)
softmax_x = exp_x / np.sum(exp_x)
return softmax_x
def predict_accuracy(result_path):
fileList = []
result_path = rf'{result_path}\labels'
filelist = GetFileList(result_path, fileList)
total = len(filelist)
num = 0
acc = 0
for txt in filelist:
txtPath = txt.split('\\')[-1:][0]
class_num, number_classes = txtPath.split('_')
if not os.path.exists(txt):
print(f'txt文件地址不存在')
with open(txt, 'r') as f:
predict_result = f.readline()
if predict_result == class_num[5:]:
acc += 1
num += 1
print(f'[{num}]预测准确率:{float(acc/num)*100:.2f}%\n进度:{float(num/total)*100:.2f}%')
with open(rf'{result_path}/predict_accuracy.txt', 'a+') as f:
f.write(f'{predict_accuracy}')
def demo(path):
list = []
for line in open(path):
line_first = str(line.split('\\')[1])
print(line_first)
list.append(line_first)
f = open(path,'w')
for i in range(len(list)):
f.write(list[i])
if __name__ == '__main__':
# 测试GetFileList函数
list = GetFileList('../support/text', [])
for path in list:
print(path)
predict_accuracy(rf'../output/mac(DAGM)/2022_05_01 01_51_04')
#demo(r'C:\Users\cleste\Desktop\孪生网络(少样本分类)\demo.txt')
2.3 针对siamese.py模块的修改
siamese.py模块中增加了如下语句:控制打印的图像数量 if float(output) > 0.01: # 置信度对于1%的图片对才会被打印