bubbliiiing/Siamese-pytorch: 这是一个孪生神经网络(Siamese network)的库,可进行图片的相似性比较。 (github.com)https://github.com/bubbliiiing/Siamese-pytorch以上是原始代码地址:站在巨人的肩膀上!
1. 对predict.py的改动: 可对比多幅图片,从而达到分类的效果
# -*- coding:utf-8 -*-
from function import *
from PIL import Image
from siamese import Siamese
# 初始化
prob = []
label = []
result = {}
# 实例化网络
model = Siamese()
# 对比(相似度计算)函数
def TwoImg(img, img2):
try:
img2 = Image.open(img2)
except:
print('Image_2 Open Error! Try again!')
# 此时img已经转化为PIL格式
probability = model.detect_image(img, img2) # 主要函数(孪生网络模型函数)
# print(probability)
return probability
# TODO 1.待检测图片地址
def inputImg():
img = input(f'待检测图片地址:')
if img == '':
img = 'input/Angelic.png' # (待)检测图片
return img
#识别函数
def recognition(img):
# 检验待检测图片是否存在
try:
img = Image.open(img)
except:
print('Image Open Error! Try again!')
# TODO 2.标签图片(label images)地址
img_path = rf'text'
# 获取标签图片名称等信息
list = GetFileList(img_path, [])
# 与每张标签图片进行对比,返回相似度
for path in list:
first, last = os.path.splitext(path)
img2 = path # 标签图片
# 对比,输入图片地址,返回相似度(tensor格式)
prob_two = TwoImg(img, img2)
# 将相似度tensor格式(prob_two)转化为float格式(similarity)
similarity = prob_two.tolist()[0]
# 将相似度结果加入相似度列表(prob)
prob.append(similarity)
# 将图片名称加入标签列表(label)
label.append(first)
# 打印相似度和图片名称列表
print(f'\n1.标签为:{label}\n'
f'2.与待检测图片相似度分别为:{prob}')
# 对相似度列表进行softmax转换
prob_sm = softmax(prob)
print(f'3.各个类别的可能性:{prob_sm}\n')
# 打印各个类别的可能性百分比
for i in range(len(label)):
print(f'{label[i]}: {prob_sm[i] * 100:.2f}%')
# 编制result字典:1.key为标签 label[i] 2.value为经softmax的相似度(即概率)prob_sm[i]
result[f'{label[i]}'] = prob_sm[i]
# 最有可能的预测结果
MAX = sorted(result,
key=result.get,
reverse=True)[0]
print(f'\n最有可能的结果是【{MAX}】,'
f'有{result[MAX] * 100:.2f}%的可能性.\n')
if __name__ == "__main__":
while True:
img = inputImg()
if img == 'exit':
break
recognition(img)
ps.相关函数存放在另一个.py模块 :function.py
import os
import numpy as np
import os.path
# 输入文件夹地址和[],返回包含子文件名称的列表
def GetFileList(dir, fileList):
newDir = dir
if os.path.isfile(dir):
fileList.append(dir)
elif os.path.isdir(dir):
for s in os.listdir(dir):
# 如果需要忽略某些文件夹,使用以下代码
# if s == "xxx":
# continue
newDir = os.path.join(dir, s)
GetFileList(newDir, fileList)
return fileList
# softmax转换 输入输出为列表
def softmax(x):
"""Compute the softmax in a numerically stable way."""
x = x - np.max(x)
exp_x = np.exp(x)
softmax_x = exp_x / np.sum(exp_x)
return softmax_x
if __name__ == '__main__':
# 测试GetFileList函数
list = GetFileList('text', [])
for path in list:
print(path)