首先,类先导入
from models.focal_loss import focal_loss
然后要声明类的实例化对象,接下来就可以调用类中的compute_loss函数
Focal_Loss = focal_loss()
for index in range(len(targets)):
sub_loss1 = Focal_Loss.compute_loss(output1[index].unsqueeze(0), targets[:, 0][index].unsqueeze(0).long().cuda())
sub_loss2 = Focal_Loss.compute_loss(output2[index].unsqueeze(0), targets[:, 1][index].unsqueeze(0).long().cuda())
具体参考如下博客:
https://blog.csdn.net/sinat_29699167/article/details/78350617