使用HRNet训练自己的模型并进行检测

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import cv2 as cv
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from torch.autograd import Variable
from config import config
from config import update_config
from PIL import Image
import argparse
from models import cls_hrnet



def prediect(cfg_file='C:\\Users\\User\\Desktop\\HRnet-class\\HRNet-Classification\\experiments\\cls_hrnet_w18_small_v2_sgd_lr5e-2_wd1e-4_bs32_x100.yaml',
             img_path=r'C:\Users\User\Desktop\HRnet-class\HRNet-Classification\imagenet\images\train\normal\im0088.jpg',
             class_names=('fall', 'normal')):
    """Classify one image with a trained HRNet classifier and print the result.

    Merges the experiment YAML into the global ``config``, builds the network
    with ``cls_hrnet.get_cls_net``, restores weights from
    ``config.TEST.MODEL_FILE`` (if set), runs a single forward pass on
    ``img_path`` and prints the predicted class index and label.

    Args:
        cfg_file: Path to the HRNet experiment YAML. Defaults to the original
            hard-coded experiment file.
        img_path: Path to the image to classify. Defaults to the original
            hard-coded sample image.
        class_names: Label for each classifier output index, in index order.
            The default reproduces the original two-class mapping
            (0 -> 'fall', 1 -> 'normal'). Indices outside this sequence print
            only the raw index.

    Returns:
        int: Index of the highest-scoring class.
    """
    # The global config is frozen after the first merge, and merging into a
    # frozen yacs CfgNode raises. defrost() first so the function can be
    # called more than once.
    config.defrost()
    config.merge_from_file(cfg_file)
    config.freeze()

    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    hrnet = cls_hrnet.get_cls_net(config)
    print('*********************************')
    print(config.TEST.MODEL_FILE)

    if config.TEST.MODEL_FILE:
        # map_location='cpu' lets a checkpoint saved on a GPU load on a
        # CPU-only machine; load_state_dict then copies the tensors into the
        # model. Weights are loaded before the DataParallel wrap, so the
        # checkpoint keys are expected without a 'module.' prefix.
        state_dict = torch.load(config.TEST.MODEL_FILE, map_location='cpu')
        hrnet.load_state_dict(state_dict)
    else:
        # No checkpoint configured: the prediction below uses untrained weights.
        print('没找到模型文件')

    gpus = list(config.GPUS)
    use_cuda = torch.cuda.is_available() and len(gpus) > 0
    if use_cuda:
        # DataParallel requires the wrapped module to live on device_ids[0].
        device = torch.device('cuda', gpus[0])
        hrnet = torch.nn.DataParallel(hrnet.to(device), device_ids=gpus)
    else:
        device = torch.device('cpu')
    hrnet.eval()

    image_size = config.MODEL.IMAGE_SIZE[0]
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    preprocess = transforms.Compose([
        # Resize the short side to size / 0.875, then center-crop to size.
        transforms.Resize(int(image_size / 0.875)),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])

    # The context manager closes the file handle. convert('RGB') gives
    # grayscale / RGBA / palette images the 3 channels Normalize expects.
    with Image.open(img_path) as pil_img:
        rgb_img = pil_img.convert('RGB')
    batch = preprocess(rgb_img).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        output = hrnet(batch)

    if use_cuda:
        torch.cuda.empty_cache()

    # A single-image batch yields a one-element tensor; take it as a Python int.
    cls_pred = int(output.argmax(dim=1).item())
    print(cls_pred)

    if cls_pred < len(class_names):
        print('This pic belong to class:')
        print(class_names[cls_pred])
    return cls_pred


# Run inference only when executed as a script, not on import.
if __name__ == '__main__':
    prediect()









评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值