需要用什么就看哪部分,看哪部分就写哪部分。
NMS 前后
# Per-batch loop: preprocess each image batch, run the model on it, and
# accumulate the forward-pass time (t0) and, optionally, the loss.
# NOTE(review): leading indentation was lost in this excerpt, so the nesting
# of the loop body, the torch.no_grad() block and the `if` is not visible
# here -- restore it from the original file before running.
for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
# Move the image batch to the target device.
img = img.to(device, non_blocking=True)
img = img.half() if half else img.float() # uint8 to fp16/32
img /= 255.0 # 0 - 255 to 0.0 - 1.0
targets = targets.to(device)
nb, _, height, width = img.shape # batch size, channels, height, width
# Disable gradient tracking for the code run under this context.
with torch.no_grad():
# Run model
# Timestamp taken right before the forward pass; the elapsed time is added
# to t0 right after it.
t = time_synchronized()
# Author's note: inf_out is a single inference output made by concatenating
# the three Detect-layer outputs; train_out holds the three per-layer
# outputs as used in training.
inf_out, train_out = model(img, augment=augment) # inference and training outputs
t0 += time_synchronized() - t
# Compute loss
if compute_loss:
# Author's note: the [:3] slice selects lbox, lobj, lcls.
loss += compute_loss([x.float() for x in train_out], targets)[1][:3] # box, obj, cls
# Run NMS
# Author's note: "relative xyw".
# NOTE(review): presumably means the target boxes are still in relative
# (normalized) coordinates at this point -- confirm against the code that follows.