from __future__ import absolute_import from __future__ import division from __future__ import print_function import cv2 as cv import torch import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.optim import torch.utils.data import torch.utils.data.distributed import torchvision.transforms as transforms from torch.autograd import Variable from config import config from config import update_config from PIL import Image import argparse from models import cls_hrnet def prediect(): # loda HRNET config.merge_from_file('C:\\Users\\User\\Desktop\\HRnet-class\\HRNet-Classification\\experiments\\cls_hrnet_w18_small_v2_sgd_lr5e-2_wd1e-4_bs32_x100.yaml') config.freeze() # # parser = argparse.ArgumentParser(description='Train network') # parser.add_argument('--TEST.MODEL_FILE', # help='model path', # type = str, # default='') # args = parser.parse_args() # update_config(config,args) cudnn.benchmark = config.CUDNN.BENCHMARK torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC torch.backends.cudnn.enabled = config.CUDNN.ENABLED hrnet = cls_hrnet.get_cls_net(config) print('*********************************') print(config.TEST.MODEL_FILE) if config.TEST.MODEL_FILE: hrnet.load_state_dict(torch.load(config.TEST.MODEL_FILE)) else: print('没找到模型文件') gpus = list(config.GPUS) hrnet = torch.nn.DataParallel(hrnet, device_ids=gpus).cuda() hrnet.eval() Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor pli_img_path = r'C:\Users\User\Desktop\HRnet-class\HRNet-Classification\imagenet\images\train\normal\im0088.jpg' pil_img = Image.open(pli_img_path) normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) input = transforms.Compose([ transforms.Resize(int(config.MODEL.IMAGE_SIZE[0] / 0.875)), transforms.CenterCrop(config.MODEL.IMAGE_SIZE[0]), transforms.ToTensor(), normalize, ])(pil_img) input = Variable(torch.unsqueeze(input, dim=0).float(), requires_grad=False) # switch to evaluate mode cls_pred = 0 with torch.no_grad(): output = hrnet(input) # print(output) # free image torch.cuda.empty_cache() cls_pred = output.argmax(dim=1) print(cls_pred) if cls_pred ==0: print('This pic belong to class:') print('fall') if cls_pred ==1: print('This pic belong to class:') print('normal') prediect()
04-02
1161
02-02
422
04-29
1921
12-26
1158
01-11
3275