孪生网络代码改动

bubbliiiing/Siamese-pytorch: 这是一个孪生神经网络(Siamese network)的库,可进行图片的相似性比较。 (github.com)icon-default.png?t=M3C8https://github.com/bubbliiiing/Siamese-pytorch以上是原始代码地址:站在巨人的肩膀上!

1. 对predict.py的改动: 可对比多幅图片,从而达到分类的效果

# -*- coding:utf-8 -*-
from function import *
from PIL import Image
from siamese import Siamese

# 初始化
prob = []
label = []
result = {}
# 实例化网络
model = Siamese()

# 对比(相似度计算)函数
def TwoImg(img, img2):
    try:
        img2 = Image.open(img2)
    except:
        print('Image_2 Open Error! Try again!')

    # 此时img已经转化为PIL格式
    probability = model.detect_image(img, img2)  # 主要函数(孪生网络模型函数)
    # print(probability)
    return probability

# TODO 1.待检测图片地址
def inputImg():
    img = input(f'待检测图片地址:')
    if img == '':
        img = 'input/Angelic.png'  # (待)检测图片
    return img

#识别函数
def recognition(img):
    # 检验待检测图片是否存在
    try:
        img = Image.open(img)
    except:
        print('Image Open Error! Try again!')

    # TODO 2.标签图片(label images)地址
    img_path = rf'text'

    # 获取标签图片名称等信息
    list = GetFileList(img_path, [])

    # 与每张标签图片进行对比,返回相似度
    for path in list:
        first, last = os.path.splitext(path)
        img2 = path  # 标签图片
        # 对比,输入图片地址,返回相似度(tensor格式)
        prob_two = TwoImg(img, img2)
        # 将相似度tensor格式(prob_two)转化为float格式(similarity)
        similarity = prob_two.tolist()[0]
        # 将相似度结果加入相似度列表(prob)
        prob.append(similarity)
        # 将图片名称加入标签列表(label)
        label.append(first)

    # 打印相似度和图片名称列表
    print(f'\n1.标签为:{label}\n'
          f'2.与待检测图片相似度分别为:{prob}')

    # 对相似度列表进行softmax转换
    prob_sm = softmax(prob)
    print(f'3.各个类别的可能性:{prob_sm}\n')

    # 打印各个类别的可能性百分比
    for i in range(len(label)):
        print(f'{label[i]}: {prob_sm[i] * 100:.2f}%')
        # 编制result字典:1.key为标签 label[i] 2.value为经softmax的相似度(即概率)prob_sm[i]
        result[f'{label[i]}'] = prob_sm[i]

    # 最有可能的预测结果
    MAX = sorted(result,
                 key=result.get,
                 reverse=True)[0]

    print(f'\n最有可能的结果是【{MAX}】,'
          f'有{result[MAX] * 100:.2f}%的可能性.\n')

if __name__ == "__main__":
    while True:
        img = inputImg()
        if img == 'exit':
            break
        recognition(img)

ps.相关函数存放在另一个.py模块 :function.py

import os
import numpy as np
import os.path


# 输入文件夹地址和[],返回包含子文件名称的列表
def GetFileList(dir, fileList):
    newDir = dir
    if os.path.isfile(dir):
        fileList.append(dir)
    elif os.path.isdir(dir):
        for s in os.listdir(dir):
            # 如果需要忽略某些文件夹,使用以下代码
            # if s == "xxx":
            # continue
            newDir = os.path.join(dir, s)
            GetFileList(newDir, fileList)
    return fileList


# softmax转换 输入输出为列表
def softmax(x):
    """Compute the softmax in a numerically stable way."""
    x = x - np.max(x)
    exp_x = np.exp(x)
    softmax_x = exp_x / np.sum(exp_x)
    return softmax_x


if __name__ == '__main__':
    # 测试GetFileList函数
    list = GetFileList('text', [])
    for path in list:
        print(path)

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

wind faded

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值