代码地址:https://github.com/Microsoft/singleshotpose
论文:Real-Time Seamless Single Shot 6D Object Pose Prediction(Tekin 等,CVPR 2018)
五、训练
python3 train.py --datacfg kitti/kitti.data --modelcfg kitti/yolo-pose.cfg --initweightfile cfg/darknet19_448.conv.23 --pretrain_num_epochs 15
训练完成后,模型权重和损失记录会保存在 backup 目录中。
六、损失曲线
import numpy as np
import matplotlib.pyplot as plt
# Plot training (and, when available, validation) loss from a costs.npz file.
#
# Array names observed in one costs.npz (printed via data.files):
#   ['training_iters', 'testing_errors_angle', 'testing_iters',
#    'training_losses', 'testing_errors_pixel', 'testing_accuracies']
# NOTE(review): that list has no 'testing_losses', so indexing it
# unconditionally raises KeyError; the validation curve is therefore drawn
# only when the file actually contains both testing keys.
COSTS_PATH = '/root/work/21server/costs.npz'
LOSS_FIGURE_PATH = 'loss.jpg'


def load_loss_curves(costs):
    """Extract the training and (optional) validation loss curves.

    Args:
        costs: mapping from array name to array, e.g. the NpzFile returned by
            ``np.load`` on costs.npz, or a plain dict.

    Returns:
        ``(train, valid)``: ``train`` is ``(iters, losses)`` taken from
        ``training_iters`` / ``training_losses``; ``valid`` is
        ``(iters, losses)`` from ``testing_iters`` / ``testing_losses``, or
        ``None`` when either of those two keys is absent.

    Raises:
        KeyError: if ``training_iters`` or ``training_losses`` is missing.
    """
    # NpzFile lists its array names in .files; a dict iterates its keys.
    names = set(getattr(costs, 'files', costs))
    train = (costs['training_iters'], costs['training_losses'])
    valid = None
    if {'testing_iters', 'testing_losses'} <= names:
        valid = (costs['testing_iters'], costs['testing_losses'])
    return train, valid


def plot_loss_curves(costs_path=COSTS_PATH, out_path=LOSS_FIGURE_PATH):
    """Plot loss against iteration from *costs_path* and save to *out_path*.

    Training loss is drawn as blue stars and validation loss (if present) as
    a red line. Each series carries its own label, so the legend always
    matches what is actually drawn.
    """
    # encoding/allow_pickle are kept from the original call. allow_pickle
    # permits unpickling object arrays, so only load files you produced.
    with np.load(costs_path, encoding='bytes', allow_pickle=True) as data:
        (train_iters, train_losses), valid = load_loss_curves(data)

    plt.figure(figsize=(6, 6))
    plt.plot(train_iters, train_losses, 'b*', label='train')
    if valid is not None:
        valid_iters, valid_losses = valid
        plt.plot(valid_iters, valid_losses, 'r-', label='validation')
    plt.ylabel('loss')
    plt.xlabel('iter')
    plt.legend(loc='upper right')
    plt.savefig(out_path)
    plt.close()


if __name__ == '__main__':
    plot_loss_curves()
七、测试结果可视化
1、新建 backup/kitti/test/gt 和 backup/kitti/test/pr 两个目录,真实的角点坐标和预测的角点坐标会分别保存在这两个目录中;
2、对 valid.py 中的测试代码部分稍作修改;
3、测试命令:
python3 valid.py --datacfg kitti/kitti.data --modelcfg kitti/yolo-pose.cfg --weightfile backup/kitti/model_backup.weights
4、可视化脚本:在测试图像上绘制预测的角点及其连线:
import cv2
import os
import numpy as np
# Draw predicted (and optionally ground-truth) 2D corners on one image.
# Paths are from the original author's machine; adjust them as needed.
# NOTE(review): the image is 000000.jpg but the prediction file is
# corners_0001.txt; confirm the two refer to the same frame.
LABEL_PATH = "/workspace/darknet/threedbbox/singleshotpose/LINEMOD/duck/labels/000000.txt"
PRED_PATH = "/workspace/darknet/threedbbox/singleshotpose/backup/duck/test/pr/corners_0001.txt"
IMAGE_PATH = "/workspace/darknet/threedbbox/singleshotpose/LINEMOD/duck/JPEGImages/000000.jpg"

# OpenCV colors are BGR tuples.
GT_COLOR = (0, 0, 255)    # red
PRED_COLOR = (255, 0, 0)  # blue


def parse_gt_corners(lines, w, h):
    """Parse ground-truth label lines into integer pixel corners.

    Each non-empty line is split on whitespace. Fields 3..18 hold eight
    normalized ``x y`` pairs, which are scaled by the image size. Field 0
    (class id) and fields 1-2 are skipped; presumably fields 1-2 are the
    centroid, to be confirmed against the label format.

    Args:
        lines: iterable of text lines, e.g. an open label file.
        w: image width in pixels.
        h: image height in pixels.

    Returns:
        List of ``(x, y)`` tuples, eight per non-empty line, rounded to the
        nearest pixel.
    """
    corners = []
    for line in lines:
        fields = line.split()
        if not fields:  # tolerate blank/trailing lines
            continue
        for i in range(1, 9):
            x = int(round(float(fields[2 * i + 1]) * w))
            y = int(round(float(fields[2 * i + 2]) * h))
            corners.append((x, y))
    return corners


def parse_pred_corners(lines):
    """Parse predicted corner lines: the first two fields are pixel ``x y``.

    Blank lines are skipped. Coordinates are rounded to the nearest pixel,
    not truncated.
    """
    corners = []
    for line in lines:
        fields = line.split()
        if not fields:  # e.g. a trailing newline at the end of the file
            continue
        corners.append((int(round(float(fields[0]))), int(round(float(fields[1])))))
    return corners


def draw_wireframe(image, points, color, edges=None):
    """Draw *points* as filled dots joined by lines, in place on *image*.

    Args:
        image: BGR image array as returned by ``cv2.imread``; it is modified.
        points: list of ``(x, y)`` integer pixel coordinates.
        color: BGR tuple used for both the dots and the lines.
        edges: optional iterable of ``(i, j)`` index pairs to connect. With
            ``None`` (the default), every distinct pair is connected, as the
            original script did. For cuboid corners that includes face and
            body diagonals; pass the 12 box edges instead once the corner
            order is confirmed.

    Returns:
        The same *image* object.
    """
    if edges is None:
        n = len(points)
        edges = [(i, j) for i in range(n) for j in range(i + 1, n)]
    for i, j in edges:
        cv2.line(image, points[i], points[j], color, 1)
    # Dots are drawn last so they stay visible on top of the lines.
    for point in points:
        cv2.circle(image, point, 3, color, -1)
    return image


def visualize(image_path=IMAGE_PATH, pre_path=PRED_PATH, label_path=LABEL_PATH,
              out_path='000000_pred.jpg', draw_gt=False,
              gt_out_path='000000_ground.jpg'):
    """Overlay predicted corners (and optionally ground truth) on an image.

    When *draw_gt* is true, the ground-truth corners are drawn first and that
    image is saved to *gt_out_path*. The predictions are then drawn on the
    same image, so *out_path* shows both.

    Raises:
        FileNotFoundError: if *image_path* cannot be read by OpenCV.
        OSError: if an output image cannot be written.
    """
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(image_path)
    h, w = image.shape[:2]

    if draw_gt:
        with open(label_path, 'r') as f:
            gt_corners = parse_gt_corners(f, w, h)
        print(gt_corners)
        draw_wireframe(image, gt_corners, GT_COLOR)
        if not cv2.imwrite(gt_out_path, image):
            raise OSError('failed to write ' + gt_out_path)

    with open(pre_path, 'r') as f:
        pred_corners = parse_pred_corners(f)
    print(pred_corners)
    draw_wireframe(image, pred_corners, PRED_COLOR)
    if not cv2.imwrite(out_path, image):
        raise OSError('failed to write ' + out_path)
    return image


if __name__ == '__main__':
    visualize()