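refine_output and target7 both have shape [batch, 17, 64, 48]. Computing the loss at each element position gives a refine_loss of the same shape, [batch, 17, 64, 48]. Applying refine_loss.mean(dim=3).mean(dim=2) then averages over the two spatial dimensions and yields a loss of shape [batch, 17].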
The following example illustrates how MSELoss behaves here:
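import torch

num_class = 17            # channel count, matching the shapes described above
output_shape = (64, 48)   # height and width, matching the shapes described above
target7 = torch.zeros((1, num_class, output_shape[0], output_shape[1]))
refine_output = torch.ones((1, num_class, output_shape[0], output_shape[1]))
# reduction='none' keeps the per-element loss (replaces the deprecated reduce=False)
criterion2 = torch.nn.MSELoss(reduction='none')
refine_loss = criterion2(refine_output, target7)  # prediction first, then target
print(refine_loss.shape)  # torch.Size([1, 17, 64, 48])
refine_loss = refine_loss.mean(dim=3).mean(dim=2)  # average over width, then height
print(refine_loss.shape)  # torch.Size([1, 17])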
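To make the two steps concrete, the same computation can be written out without PyTorch. This is only a minimal sketch on nested Python lists, not how PyTorch implements it. The helper names elementwise_mse and spatial_mean, and the toy sizes (batch=1, 2 channels, 2x3 maps), are hypothetical and chosen for illustration.

```python
# Plain-Python sketch of an element-wise squared error followed by
# .mean(dim=3).mean(dim=2), on nested lists shaped [batch][channel][height][width].

def elementwise_mse(pred, target):
    """Squared error at every position; output has the same shape as the inputs."""
    return [[[[(p - t) ** 2 for p, t in zip(p_row, t_row)]
              for p_row, t_row in zip(p_map, t_map)]
             for p_map, t_map in zip(p_sample, t_sample)]
            for p_sample, t_sample in zip(pred, target)]

def spatial_mean(loss):
    """Average over width, then over height, leaving [batch][channel]."""
    reduced = []
    for sample in loss:
        per_channel = []
        for heatmap in sample:
            row_means = [sum(row) / len(row) for row in heatmap]   # mean over dim=3
            per_channel.append(sum(row_means) / len(row_means))    # mean over dim=2
        reduced.append(per_channel)
    return reduced

# Toy data (hypothetical sizes): batch=1, channels=2, height=2, width=3.
pred = [[[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
         [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]]
target = [[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
           [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]]

loss = elementwise_mse(pred, target)      # same shape as pred: [1][2][2][3]
per_channel_loss = spatial_mean(loss)     # shape [1][2]
print(per_channel_loss)  # [[1.0, 4.0]]
```

Because every row has the same length, the mean of the row means equals the mean over all height x width positions, so in PyTorch refine_loss.mean(dim=(2, 3)) should give the same [batch, 17] result as the chained call.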