第一步 训练结果
图上是我训练200个epoch得出来的结果,效果并不理想,因为我的训练集只有700张,但是我还是不明白为什么!!map50只有55左右,有没有大神告诉我!!
第二步 预测图片
目前我只会预测图片,不会弄文件夹。有没有大神教一教呀~
根据如下代码,创建predict.py文件,进行图片预测。
import numpy as np
from models.detr import build
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
torch.set_grad_enabled(False)
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
transform_input = transforms.Compose([transforms.Resize(800),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=1)
def rescale_bboxes(out_bbox, size):
img_w, img_h = size
b = box_cxcywh_to_xyxy(out_bbox)
b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32, device="cuda")
return b
def plot_results(pil_img, prob, boxes, img_save_path):
plt.figure(figsize=(16, 10))
plt.imshow(pil_img)
ax = plt.gca()
colors = COLORS * 100
for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
fill=False, color=c, linewidth=3))
cl = p.argmax()
text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
ax.text(xmin, ymin, text, fontsize=9,
bbox=dict(facecolor='yellow', alpha=0.5))
plt.savefig(img_save_path)
plt.axis('off')
plt.show()
def main(checkpoint_path, img_path, img_save_path):
##加载模型
args = torch.load(checkpoint_path)['args']
model = build(args)[0]
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# 加载模型参数
model_data = torch.load(checkpoint_path)['model']
model.load_state_dict(model_data)
model.eval()
img = Image.open(img_path).convert('RGB')
size = img.size
inputs = transform_input(img).unsqueeze(0)
outputs = model(inputs.to(device))
# 这类最后[0, :, :-1]索引其实是把背景类筛选掉了
probs = outputs['pred_logits'].softmax(-1)[0, :, :-1]
# 可修改阈值,只输出概率大于0.7的物体
keep = probs.max(-1).values > 0.3
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], size)
# 保存输出结果
ori_img = np.array(img)
plot_results(ori_img, probs[keep], bboxes_scaled, img_save_path)
if __name__ == "__main__":
CLASSES = ['N/A',"fire","background"]####这边类别只有fire一类
main(checkpoint_path="./output3/checkpoint.pth",###checkpoint的路径
img_path="coco3/test2017/images (1).jpg",####图片路径
img_save_path="predict/result1.jpg") ####预测图片的结果
第三步 结果
纪念一下第一次预测火焰检测成功的图片。
预测还行吧,不过我这个图片有点失真和马赛克了。
参考
【DETR】DETR训练VOC数据集/自己的数据集_pascal voc数据集转化成detr数据集-CSDN博客
跪谢大神!