#coding=utf-8
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.autograd import Variable
import matplotlib.pyplot as plt
import pandas as pd
from model import myconvNet
import os
import cv2
import time
from dataloader import tempDataset
import os
import yaml
import matplotlib.pyplot as plt
from importlib.abc import Loader
from PIL import Image
import cv2
import glob
import numpy as np
from torchvision import transforms
def transfunc(image):
    """Resize *image* to 224x224 and convert it to a tensor.

    The two torchvision transforms are applied one after the other, in the
    same order a ``transforms.Compose([Resize, ToTensor])`` pipeline would.

    Args:
        image: input image accepted by ``transforms.Resize`` (a PIL image
            at the call site in this file).

    Returns:
        The output of ``transforms.ToTensor()`` on the resized image.
    """
    resized = transforms.Resize([224, 224])(image)
    to_tensor = transforms.ToTensor()
    return to_tensor(resized)
def _load_labels(data_path):
    """Return every ``image_data`` record that has exactly six keypoints.

    Scans ``data_path/*/*.yaml`` in sorted order so output indices are
    reproducible between runs.
    """
    label_list = []
    for yaml_file in sorted(glob.glob(os.path.join(data_path, "*", "*.yaml"))):
        # Close the file promptly (the original leaked one handle per YAML).
        # NOTE(review): yaml.safe_load would suffice if label files are plain data.
        with open(yaml_file) as f:
            datalabel = yaml.load(f, Loader=yaml.FullLoader)
        if not datalabel:
            continue  # empty YAML document
        for record in datalabel.get('image_data') or []:
            if len(record['keypoints']) == 6:
                label_list.append(record)
    return label_list


def _keypoints_to_array(label):
    """Flatten a record's keypoints into ``[u1, v1, u2, v2, ..., u6, v6]``.

    Raises KeyError/TypeError when a keypoint lacks ``image_coords``/``u``/``v``.
    """
    return np.array([
        coord
        for keypoint in label['keypoints']
        for coord in (keypoint['image_coords']['u'], keypoint['image_coords']['v'])
    ])


def _load_model(model_path):
    """Build a 12-output ``myconvNet`` on the GPU, load weights, set eval mode."""
    net = myconvNet(nb_out=12)
    net.float().cuda()
    net.load_state_dict(torch.load(model_path))
    net.eval()
    return net


def test_plot(data_path, model_path='../save_model/Iter_100_myconvnet.pt',
              output_dir='./plot_result', show_gt=False):
    """Run the keypoint model on every labelled image and save one plot each.

    Args:
        data_path: directory holding ``*/*.yaml`` label files; each record's
            ``image_id`` is resolved relative to this directory.
        model_path: state-dict file loaded into ``myconvNet(nb_out=12)``.
        output_dir: where ``{index}_{num}_result.png`` plots are written;
            created if missing.
        show_gt: also scatter the ground-truth keypoints (cyan) next to the
            predictions (red).

    Records whose image cannot be read, or whose keypoints are malformed,
    are skipped with a message instead of stopping in a debugger.
    """
    label_list = _load_labels(data_path)
    num = len(label_list)
    if num == 0:
        print("no labelled images with 6 keypoints found under {}".format(data_path))
        return
    os.makedirs(output_dir, exist_ok=True)

    # Build the network and read the weights once, not once per image.
    net = _load_model(model_path)
    # Predictions are mapped to pixels as (p + 0.5) * [640, 480, ...].
    # NOTE(review): assumes source images are 640x480 and the net was trained
    # on coordinates normalised this way -- confirm against the training code.
    scale = np.array([[640, 480] * 6])

    for index, label in enumerate(label_list):
        image_path = os.path.join(data_path, label['image_id'])
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:  # cv2.imread signals failure by returning None
            print("skip {}: could not read image".format(image_path))
            continue
        try:
            gt = _keypoints_to_array(label)
        except (KeyError, TypeError) as err:
            print("skip {}: malformed keypoints ({!r})".format(image_path, err))
            continue

        vis_demo_img = Image.fromarray(image)
        test_image = transfunc(vis_demo_img).unsqueeze(0).cuda()
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred_points = net(test_image).cpu().numpy()
        pred_points = (pred_points + 0.5) * scale

        plt.imshow(vis_demo_img, cmap='gray')
        if show_gt:
            plt.scatter(gt[::2], gt[1::2], c='#00CED1')  # ground truth
        plt.scatter(pred_points[0][::2], pred_points[0][1::2], c='#DC143C')  # prediction
        filepath = os.path.join(output_dir, "{}_{}_result.png".format(index, num))
        plt.savefig(filepath)
        plt.clf()
        print("have saved {}".format(filepath))
if __name__ == "__main__":
    # Script entry point: plot model predictions for the labelled training split.
    data_path = "../data_labelled/train"
    test_plot(data_path)
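# test pytorch 测试模型绘图到图片 (test a PyTorch model and plot its predictions onto images)
# 最新推荐文章于 2023-01-01 20:19:19 发布 (scraped page residue: "latest recommended article published 2023-01-01 20:19:19")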