一直没有找到很好的计数代码,自己做了一点尝试修改detect.py文件。
https://github.com/ultralytics/yolov5/blob/v6.1/detect.py
定位到原代码中的 # Print results,作以下修改。
还不能实现单个类内的从1到N的计数,有会的朋友请不吝赐教。(这个2023.6.3已解决,开心)
肉眼看了一下缝针的数量,8个排针+6个三角针+1个散放的针=15个,数量是欠精确的。框框实在太密集了,以至于无法判断倒底是左侧的算多了,还是右侧的算多了。
if len(det):
# Rescale boxes from img_size to im0 size
det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
Results = "Results: "
# Print results
for c in det[:, -1].unique():
n = (det[:, -1] == c).sum() # detections per class
s += '\n'+f"{n} {names[int(c)]}{'s' * (n > 1)}" # add to string
Results += '\n'+f"{n} {names[int(c)]}" # TODO 加了一个变量Results
#s += f"{n} {names[int(c)]}{'s' * (n > 1)}, +" # add to string
# Write results
for *xyxy, conf, cls in reversed(det):
if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
with open(txt_path + '.txt', 'a') as f:
f.write(('%g ' * len(line)).rstrip() % line + '\n')
if save_img or save_crop or view_img: # TODO # Add bbox to image
c = int(cls) # integer class分类数
label = None if hide_labels else (names[c] if hide_conf else f'{names[c]} {conf:.2f} {(det[:, -1] == c).sum()}') # TODO 标签计数展示加在了末尾
annotator.box_label(xyxy, label, color=colors(c, True))
if save_crop:
save_one_box(xyxy, imc, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True)
# Stream results
im0 = annotator.result()
if view_img:
cv2.imshow(str(p), im0)
cv2.waitKey(1) # 1 millisecond
# Save results (image with detections)
if save_img:
# 需用循环的方式显示多行,因为cv2.putText对换行转义符'\n'显示为'?'
y0, dy = 50, 40
for i, txt in enumerate(Results.split('\n')):
y = y0 + i * dy
cv2.putText(im0, txt, (50, y), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2, 2)
#cv2.putText(im0, Results, (10, 10), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 3) # TODO 左上角显示检测标签和数量
原图和检测结果如下
在yolov5ncnn上,canvas那一块修改代码后的计数效果图
修改showObjects函数
public void showObjects(YoloV5Ncnn.Obj[] objects)
{
if (objects == null)
{
picture.setImageBitmap(bitmap);
return;
}
// draw objects on bitmap
Bitmap rgba = bitmap.copy(Bitmap.Config.ARGB_8888, true);
final int[] colors = new int[] {
Color.rgb( 54, 67, 244),
Color.rgb( 99, 30, 233),
Color.rgb(176, 39, 156),
Color.rgb(183, 58, 103),
Color.rgb(181, 81, 63),
Color.rgb(243, 150, 33),
Color.rgb(244, 169, 3),
Color.rgb(212, 188, 0),
Color.rgb(136, 150, 0),
Color.rgb( 80, 175, 76),
Color.rgb( 74, 195, 139),
Color.rgb( 57, 220, 205),
Color.rgb( 59, 235, 255),
Color.rgb( 7, 193, 255),
Color.rgb( 0, 152, 255),
Color.rgb( 34, 87, 255),
Color.rgb( 72, 85, 121),
Color.rgb(158, 158, 158),
Color.rgb(139, 125, 96)
};
Canvas canvas = new Canvas(rgba);
Paint paint = new Paint();
paint.setStyle(Paint.Style.STROKE);
paint.setStrokeWidth(4);
Paint textbgpaint = new Paint();
textbgpaint.setColor(Color.WHITE);
textbgpaint.setStyle(Paint.Style.FILL);
Paint textpaint = new Paint();
textpaint.setColor(Color.BLACK);
textpaint.setTextSize(26);
textpaint.setTextAlign(Paint.Align.LEFT);
//增加一个笔刷textpaint2
Paint textpaint2 = new Paint();
textpaint2.setColor(Color.RED);
textpaint2.setTextSize(26);
textpaint2.setTextAlign(Paint.Align.LEFT);
for (int i = 0; i < objects.length; i++)
{
paint.setColor(colors[i % 19]);
canvas.drawRect(objects[i].x, objects[i].y, objects[i].x + objects[i].w, objects[i].y + objects[i].h, paint);
// draw filled text inside image
{
list.add(objects[i].label);//将标签添加到列表中
//下面的3行代码,如果不需要类内计算,直接用String text = objects[i].label即可,获取key为'A'的value,您可以使用以下代码:value_of_a = my_map.get('A');不要用value_of_a = my_map['A'];如果该键不存在于字典中,则会抛出KeyError异常。
Map<String, Integer> temp_map = new HashMap<>();
for (String l : list) {temp_map.merge(l, 1, Integer::sum);}
String text = objects[i].label + temp_map.get(objects[i].label);//获取对象的标签+出现的次数
float text_width = textpaint.measureText(text);
float text_height = -textpaint.ascent() + textpaint.descent();
float x = objects[i].x;
float y = objects[i].y - text_height;
if (y < 0)
y = 0;
if (x + text_width > rgba.getWidth())
x = rgba.getWidth() - text_width;
canvas.drawRect(x, y, x + text_width, y + text_height, textbgpaint);
canvas.drawText(text, x, y - textpaint.ascent(), textpaint);
}
}
String Number = "物品总数:" + objects.length + "...";//字符串变量Number保存物品总数
Map<String, Integer> map = new HashMap<>();//Map型变量保存
for (String l : list) {
map.merge(l, 1, Integer::sum);
}
Date date = new Date();
canvas.drawText(String.format("%tc%n",date), 30, 30, textpaint2);//第一行字,显示当前时间戳
canvas.drawText(Number, 30, 60, textpaint2);//显示总物品数量
canvas.drawText(map.toString(), 30, 90, textpaint2);//显示map,包含每个类和数量
picture.setImageBitmap(rgba);
}
这是带类内计算的效果。