Original article: https://zhuanlan.zhihu.com/p/28527749
import torch
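# Per-element exponent; it starts at 1, so elements exactly equal to 0.5 keep gamma = 1.
# ones_like already matches focal_weight's device and dtype, so no .cuda() call is needed.
gamma = torch.ones_like(focal_weight)
gamma[focal_weight > 0.5] = 0.4  # for weights in (0, 1): exponent < 1 raises the weight toward 1
gamma[focal_weight < 0.5] = 2.2  # for weights in (0, 1): exponent > 1 pushes the weight toward 0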
focal_weight = alpha_factor * torch.pow(focal_weight, gamma)
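The snippet above relies on focal_weight and alpha_factor being defined elsewhere. A minimal self-contained sketch follows, assuming they are built the way a typical RetinaNet-style focal loss builds them (alpha for positives, 1 - alpha for negatives, and 1 - p for positives, p for negatives). The function name piecewise_gamma_focal_weight, the sample tensors, and alpha = 0.25 are hypothetical and chosen only for illustration.

```python
import torch


def piecewise_gamma_focal_weight(focal_weight, alpha_factor,
                                 low_gamma=0.4, high_gamma=2.2, threshold=0.5):
    """Scale focal_weight by alpha_factor using a per-element exponent.

    Elements above the threshold use low_gamma, elements below it use
    high_gamma, and elements exactly at the threshold keep an exponent of 1.
    """
    # ones_like matches the device and dtype of focal_weight.
    gamma = torch.ones_like(focal_weight)
    gamma[focal_weight > threshold] = low_gamma
    gamma[focal_weight < threshold] = high_gamma
    return alpha_factor * torch.pow(focal_weight, gamma)


# Hypothetical inputs: predicted probabilities and binary targets for 4 anchors.
classification = torch.tensor([0.9, 0.2, 0.7, 0.1])
targets = torch.tensor([1.0, 0.0, 0.0, 1.0])
alpha = 0.25  # assumed value, not taken from the notes

# Assumed construction: alpha for positives, 1 - alpha for negatives.
alpha_factor = torch.where(targets == 1.0,
                           torch.full_like(targets, alpha),
                           torch.full_like(targets, 1.0 - alpha))
# Assumed construction: 1 - p for positives, p for negatives.
focal_weight = torch.where(targets == 1.0, 1.0 - classification, classification)

weights = piecewise_gamma_focal_weight(focal_weight, alpha_factor)
print(weights)
```

Here the weights 0.1 and 0.2 are raised to the power 2.2, while 0.7 and 0.9 are raised to the power 0.4, so the gap between small and large weights grows compared with a single fixed gamma.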
Does not work for object detection:
import torch
impor