PyTorch 多标签分类模型的单张图像推理

import models
import torch
import cv2
import os
import time
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# TODO: label names used for classification training. The order matters and
# must be identical to the order used when the model was trained!
CLASSNAMES = [
    '1',
    '2',
    '3',
    '4',
]

def inference_one(model, img_path, out_path):
    """Run single-image inference with a PyTorch multi-label classifier.

    The image is loaded as RGB, resized to 256x256 and passed through
    ``model``. The raw outputs are turned into independent per-class
    probabilities with a sigmoid. The probabilities, the highest-scoring
    class and the full ranking are printed, and a figure showing the resized
    image titled with the top-3 predictions is saved into ``out_path``.

    Args:
        model: Callable network. It is invoked on a (1, 3, 256, 256) float
            tensor placed on the module-level ``device`` and is expected to
            return one score per entry of ``CLASSNAMES`` for that single
            image. The caller is responsible for putting it in eval mode.
        img_path: Path of the image file to classify.
        out_path: Existing directory that receives the annotated figure,
            saved as ``<image stem>_res<image extension>``.

    Returns:
        A list with one int per class: 1 if that class' probability is
        greater than 0.5, otherwise 0.
    """
    imgname = os.path.basename(img_path)
    print(imgname)

    # ========== (1) PIL-based loading =======================================
    transform_valid = transforms.Compose([
        transforms.Resize((256, 256), interpolation=2),
        transforms.ToTensor(),
    ])
    # The context manager closes the file handle; convert('RGB') guarantees
    # three channels even for grayscale, palette or RGBA inputs.
    with Image.open(img_path) as img:
        img_ = transform_valid(img.convert('RGB')).unsqueeze(0)  # add batch dim

    # # ========== (2) OpenCV-based loading; to use it, comment out (1) =======
    # img = cv2.imread(img_path)
    # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # img = cv2.resize(img, (256, 256))
    # img_ = torch.from_numpy(img).float().permute(2, 0, 1).unsqueeze(0)/255

    img_ = img_.to(device)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        # Must match training: sigmoid for multi-label, softmax for single-label.
        outputs = torch.sigmoid(model(img_))
    print(outputs)

    probs = outputs[0]
    percentage = probs * 100

    # Independent 0/1 decision per label at a 0.5 probability threshold.
    y_pred = [1 if p > 0.5 else 0 for p in probs.tolist()]
    # print(imgname, y_pred)

    # Report the class with the highest probability.
    _, indices = torch.max(outputs, 1)
    best = int(indices[0])
    print('max predicted:', CLASSNAMES[best], round(percentage[best].item(), 2))

    # Report every class, sorted from most to least probable.
    _, order = torch.sort(probs, descending=True)
    predict = [
        (CLASSNAMES[idx], round(percentage[idx].item(), 2))
        for idx in order.tolist()
    ]
    print('all predicted:', predict)

    # Visualise the network input with the top-3 predictions as the title.
    image = np.transpose(img_.squeeze(0).detach().cpu().numpy(), (1, 2, 0))
    plt.imshow(image)
    plt.axis('off')
    plt.title(f"{predict[:3]}")
    stem, ext = os.path.splitext(imgname)
    plt.savefig(os.path.join(out_path, f"{stem}_res{ext}"))
    plt.show()
    # Release the figure so repeated calls neither draw on top of the previous
    # image nor accumulate memory when the backend is non-interactive.
    plt.close()

    return y_pred


if __name__ == "__main__":
    # Batch driver: load the trained checkpoint, run inference_one on every
    # .jpg file in root_path, and write one "<filename>\t<y_pred>" line per
    # image to a text file. The total elapsed time is printed at the end.
    start_time = time.time()

    model = models.model(pretrained=False, requires_grad=False).to(device)
    # map_location lets a checkpoint that was saved on a GPU be loaded on a
    # CPU-only machine (and vice versa).
    checkpoint = torch.load(
        '/workspace/code/multi-label_image_classification/outputs/model_1117.pth',
        map_location=device,
    )
    model.load_state_dict(checkpoint['model_state_dict'])    # load model weights state_dict
    model.eval()

    root_path = "/workspace/code/multi-label_image_classification/input/biaoji-classifier/Multi_Label_dataset/Images"
    out_path = "/workspace/code/multi-label_image_classification/resimg"
    os.makedirs(out_path, exist_ok=True)

    # The context manager closes the result file even if inference raises.
    with open("/workspace/code/multi-label_image_classification/resimg.txt", "w", encoding="utf-8") as txt:
        # Sorted so the result file has a reproducible order across runs.
        for filename in sorted(os.listdir(root_path)):
            # Match on the suffix so names like "a.jpg.bak" are skipped.
            if not filename.endswith(".jpg"):
                continue
            img_path = os.path.join(root_path, filename)
            y_pred = inference_one(model, img_path, out_path)
            txt.write(f"{filename}\t{y_pred}\n")
            # Flush per image so partial results survive an interrupted run.
            txt.flush()

    end_time = time.time()
    print("cost times:", end_time - start_time)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值