# One training epoch: iterate batches, compute the four detection losses,
# back-propagate, clip the global gradient norm, and step the optimizer.
# NOTE(review): assumes `model`, `optimizer`, `train_loader`, `train_loss`,
# `args`, `epoch`, and log file `f` are defined earlier in the file.
loop2 = tqdm(enumerate(train_loader), total=len(train_loader))
for batch_idx, data in loop2:
    # Move the tensors the model needs on-GPU; labels are passed through as-is.
    img = data['img'].cuda(args.gpu, non_blocking=True)
    label = data['label']
    heatmap_t = data['heatmap_t'].cuda(args.gpu, non_blocking=True)
    smoothlabel = data['smooth_label']

    center_loss, scale_loss, offset_loss, theta_loss = model(
        {'img': img, 'label': label, 'heatmap_t': heatmap_t, 'smooth_label': smoothlabel}
    )
    # theta loss weight kept explicit (1.0) so it is easy to tune later.
    total_loss = center_loss + scale_loss + offset_loss + 1.0 * theta_loss

    optimizer.zero_grad()
    total_loss.backward()
    # FIX: clip the GLOBAL gradient norm across all parameters in one call.
    # The original looped over parameters and clipped each tensor's norm
    # independently, which is a different (almost certainly unintended) operation.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
    optimizer.step()

    train_loss.append(float(total_loss))
    if args.rank == 0:
        # FIX: '{}\{}' -> '{}/{}' — '\{' is an invalid escape sequence that
        # rendered a literal backslash between epoch and batch index.
        str_template = (
            '{}/{} | Center loss: {:1.5f} | scale loss: {:1.5f} | '
            'offset loss: {:1.5f}| theta loss:{:1.5f} | running loss: {:1.5f}'
        )
        loop2.set_description(str_template.format(
            epoch, batch_idx,
            float(center_loss), float(scale_loss),
            float(offset_loss), float(theta_loss),
            np.mean(train_loss),
        ))
        # Log the per-batch losses as one space-separated line
        # (same bytes as the original chain of f.write calls).
        f.write(' '.join(str(float(v)) for v in (
            center_loss, scale_loss, offset_loss, theta_loss, np.mean(train_loss),
        )))
        f.write('\n')
# tqdm 按数据流实施更新 (tqdm progress bar is updated per data batch)
# 最新推荐文章于 2024-10-17 17:17:28 发布 (scraped page footer — not code)