复杂累加求平均用一行代码实现,呵呵,我为什么要搞这么无聊的事……
for i in range(0, bs):
sum = 0
dis_max = dis[i][which_max[i]]
sum += dis_max
dis = torch.mean(torch.tensor(sum))
dis_max = torch.mean(torch.tensor(list(dis[i][which_max[i]] for i in range(0, bs)))).cuda()
构建list然后对list进行操作