pytorch模型的多batch推理

pytorch模型的多batch推理,这里用的是一个多标签分类模型,所以最后用的不是softmax而是sigmod。

import cv2
import time
import models
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

# TODO 分类训练的标签类型,顺序很重要,一定和模型训练时保持一致!!!
CLASSNAMES = ['1', '2’, ‘3', '4']

class TestDataset(Dataset):
    def __init__(self, imgpath):
        self.imgpath = imgpath
        self.image_names = list(os.listdir(imgpath))

    def __len__(self):
        return len(self.image_names)
    
    def __getitem__(self, index):
        path = f"{self.imgpath}/{self.image_names[index]}"
        image = Image.open(path)
        transform = transforms.Compose([
            # transforms.ToPILImage(),
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        image = transform(image)
        return {
            "id": torch.tensor(index, dtype=torch.float32),
            'image': torch.tensor(image, dtype=torch.float32),
        }
    

if __name__ == "__main__":
    start_time =  time.time()
    # # initialize the computation device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    #intialize the model
    model = models.model(pretrained=False, requires_grad=False).to(device)
    # load the model checkpoint
    checkpoint = torch.load('/workspace/code/multi-label_image_classification/outputs/model_1117.pth')
    # load model weights state_dict
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    '''多batch推理'''
    batchsize = 8
    numworkers = 4
    inputpath = "/workspace/code//input/Images"
    test_data = TestDataset(inputpath)
    test_loader = DataLoader(test_data, batch_size=batchsize, shuffle=False, sampler=None, batch_sampler=None, num_workers=numworkers)
    print(len(test_data), len(test_loader))
    
    out_path = "/workspace/code/multi-label_image_classification/resimg1"
    os.makedirs(out_path, exist_ok=True)

    # 分类结果txt
    txt = open("resimg_batch.txt", "w")
    for counter, data in enumerate(test_loader):
        id = data['id'].to(device)
        image = data['image'].to(device)
        outputs = model(image)
        for i, output in enumerate(outputs):
            # imgname = test_data.image_names[counter*batchsize+i]
            imgname = test_data.image_names[int(id[i])]         # int(id[i]) == counter*batchsize+i
            output = torch.sigmoid(output)
            output = output.detach().cpu()

    end_time =  time.time()
    print("cost times:", end_time - start_time)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值