准备换个loss函数,传入参数时一直报各种shape和维度不匹配的错误。最近经常听到"简化问题"的提倡,所以开始重视单元测试。
针对一个类,先找到它前后的调用关系,然后直接在当前模块文件中该类的下面写单元测试:
#coding=utf-8 """ The implementation of the paper: Region Mutual Information Loss for Semantic Segmentation. """ # python 2.X, 3.X compatibility from __future__ import print_function from __future__ import division from __future__ import absolute_import import torch import torch.nn as nn import torch.nn.functional as F from losses.rmi import rmi_utils _euler_num = 2.718281828 # euler number _pi = 3.14159265 # pi _ln_2_pi = 1.837877 # ln(2 * pi) _CLIP_MIN = 1e-6 # min clip value after softmax or sigmoid operations _CLIP_MAX = 1.0 # max clip value after softmax or sigmoid operations _POS_ALPHA = 5e-4 # add this factor to ensure the AA^T is positive definite _IS_SUM = 1 # sum the loss per channel __all__ = ['RMILoss'] class RMILoss(nn.Module): """ region mutual information I(A, B) = H(A) + H(B) - H(A, B) This version need a lot of memory if do not dwonsample. """ def __init__(self, num_classes=21, rmi_radius=3, rmi_pool_way=0, rmi_pool_size=3, rmi_pool_stride=3, loss_weight_lambda=0.5, lambda_way=1): super(RMILoss, self).__init__() self.num_classes = num_classes # radius choices assert rmi_radius in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] self.rmi_radius = rmi_radius assert rmi_pool_way in [0, 1, 2, 3] self.rmi_pool_way = rmi_pool_way # set the pool_size = rmi_pool_stride assert rmi_pool_size == rmi_pool_stride self.rmi_pool_size = rmi_pool_size self.rmi_pool_stride = rmi_pool_stride self.weight_lambda = loss_weight_lambda self.lambda_way = lambda_way # dimension of the distribution self.half_d = self.rmi_radius * self.rmi_radius self.d = 2 * self.half_d self.kernel_padding = self.rmi_pool_size // 2 # ignore class self.ignore_index = 255 def forward(self, logits_4D, labels_4D): loss = self.forward_sigmoid(logits_4D, labels_4D) #loss = self.forward_softmax_sigmoid(logits_4D, labels_4D) return loss def forward_softmax_sigmoid(self, logits_4D, labels_4D): """ Using both softmax and sigmoid operations. 
Args: logits_4D : [N, C, H, W], dtype=float32 labels_4D : [N, H, W], dtype=long """ # PART I -- get the normal cross entropy loss normal_loss = F.cross_entropy(input=logits_4D, target=labels_4D.long(), ignore_index=self.ignore_index, reduction='mean') # PART II -- get the lower bound of the region mutual information # get the valid label and logits # valid label, [N, C, H, W] label_mask_3D = labels_4D < self.num_classes valid_onehot_labels_4D = F.one_hot(labels_4D.long() * label_mask_3D.long(), num_classes=self.num_classes).float() label_mask_3D = label_mask_3D.float() valid_onehot_labels_4D = valid_onehot_labels_4D * label_mask_3D.unsqueeze(dim=3) valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False) # valid probs probs_4D = F.sigmoid(logits_4D) * label_mask_3D.unsqueeze(dim=1) probs_4D = probs_4D.clamp(min=_CLIP_MIN, max=_CLIP_MAX) # get region mutual information rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D) # add together final_loss = (self.weight_lambda * normal_loss + rmi_loss * (1 - self.weight_lambda) if self.lambda_way else normal_loss + rmi_loss * self.weight_lambda) return final_loss def forward_sigmoid(self, logits_4D, labels_4D): """ Using the sigmiod operation both. 
Args: logits_4D : [N, C, H, W], dtype=float32 labels_4D : [N, H, W], dtype=long """ # label mask -- [N, H, W, 1] label_mask_3D = labels_4D < self.num_classes # valid label valid_onehot_labels_4D = F.one_hot(labels_4D.long() * label_mask_3D.long(), num_classes=self.num_classes).float() label_mask_3D = label_mask_3D.float() label_mask_flat = label_mask_3D.view([-1, ]) valid_onehot_labels_4D = valid_onehot_labels_4D * label_mask_3D.unsqueeze(dim=3)#原本的维度是3,但是程序陷入睡眠状态,不知道在等什么事件 valid_onehot_labels_4D.requires_grad_(False) # PART I -- calculate the sigmoid binary cross entropy loss valid_onehot_label_flat = valid_onehot_labels_4D.view([-1, self.num_classes]).requires_grad_(False) logits_flat = logits_4D.permute(0, 2, 3, 1).contiguous().view([-1, self.num_classes]) # binary loss, multiplied by the not_ignore_mask valid_pixels = torch.sum(label_mask_flat) binary_loss = F.binary_cross_entropy_with_logits(logits_flat, #100 1 target=valid_onehot_label_flat, #20 1 weight=label_mask_flat.unsqueeze(dim=1), reduction='sum') bce_loss = torch.div(binary_loss, valid_pixels + 1.0) # PART II -- get rmi loss # onehot_labels_4D -- [N, C, H, W] probs_4D = logits_4D.sigmoid() * label_mask_3D.unsqueeze(dim=1) + _CLIP_MIN valid_onehot_labels_4D = valid_onehot_labels_4D.permute(0, 3, 1, 2).requires_grad_(False) # get region mutual information rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D) # add together final_loss = (self.weight_lambda * bce_loss + rmi_loss * (1 - self.weight_lambda) if self.lambda_way else bce_loss + rmi_loss * self.weight_lambda) return final_loss def rmi_lower_bound(self, labels_4D, probs_4D): """ calculate the lower bound of the region mutual information. 
Args: labels_4D : [N, C, H, W], dtype=float32 probs_4D : [N, C, H, W], dtype=float32 """ assert labels_4D.size() == probs_4D.size() p, s = self.rmi_pool_size, self.rmi_pool_stride if self.rmi_pool_stride > 1: if self.rmi_pool_way == 0: labels_4D = F.max_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding) probs_4D = F.max_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding) elif self.rmi_pool_way == 1: labels_4D = F.avg_pool2d(labels_4D, kernel_size=p, stride=s, padding=self.kernel_padding) probs_4D = F.avg_pool2d(probs_4D, kernel_size=p, stride=s, padding=self.kernel_padding) elif self.rmi_pool_way == 2: # interpolation shape = labels_4D.size() new_h, new_w = shape[2] // s, shape[3] // s labels_4D = F.interpolate(labels_4D, size=(new_h, new_w), mode='nearest') probs_4D = F.interpolate(probs_4D, size=(new_h, new_w), mode='bilinear', align_corners=True) else: raise NotImplementedError("Pool way of RMI is not defined!") # we do not need the gradient of label. label_shape = labels_4D.size() n, c = label_shape[0], label_shape[1] # combine the high dimension points from label and probability map. 
new shape [N, C, radius * radius, H, W] la_vectors, pr_vectors = rmi_utils.map_get_pairs(labels_4D, probs_4D, radius=self.rmi_radius, is_combine=0) la_vectors = la_vectors.view([n, c, self.half_d, -1]).type(torch.cuda.DoubleTensor).requires_grad_(False) pr_vectors = pr_vectors.view([n, c, self.half_d, -1]).type(torch.cuda.DoubleTensor) # small diagonal matrix, shape = [1, 1, radius * radius, radius * radius] diag_matrix = torch.eye(self.half_d).unsqueeze(dim=0).unsqueeze(dim=0) # the mean and covariance of these high dimension points # Var(X) = E(X^2) - E(X) E(X), N * Var(X) = X^2 - X E(X) la_vectors = la_vectors - la_vectors.mean(dim=3, keepdim=True) la_cov = torch.matmul(la_vectors, la_vectors.transpose(2, 3)) pr_vectors = pr_vectors - pr_vectors.mean(dim=3, keepdim=True) pr_cov = torch.matmul(pr_vectors, pr_vectors.transpose(2, 3)) # https://github.com/pytorch/pytorch/issues/7500 # waiting for batched torch.cholesky_inverse() pr_cov_inv = torch.inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA) # if the dimension of the point is less than 9, you can use the below function # to acceleration computational speed. #pr_cov_inv = utils.batch_cholesky_inverse(pr_cov + diag_matrix.type_as(pr_cov) * _POS_ALPHA) la_pr_cov = torch.matmul(la_vectors, pr_vectors.transpose(2, 3)) # the approxiamation of the variance, det(c A) = c^n det(A), A is in n x n shape; # then log det(c A) = n log(c) + log det(A). # appro_var = appro_var / n_points, we do not divide the appro_var by number of points here, # and the purpose is to avoid underflow issue. # If A = A^T, A^-1 = (A^-1)^T. appro_var = la_cov - torch.matmul(la_pr_cov.matmul(pr_cov_inv), la_pr_cov.transpose(-2, -1)) #appro_var = la_cov - torch.chain_matmul(la_pr_cov, pr_cov_inv, la_pr_cov.transpose(-2, -1)) #appro_var = torch.div(appro_var, n_points.type_as(appro_var)) + diag_matrix.type_as(appro_var) * 1e-6 # The lower bound. If A is nonsingular, ln( det(A) ) = Tr( ln(A) ). 
rmi_now = 0.5 * rmi_utils.log_det_by_cholesky(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA) #rmi_now = 0.5 * torch.logdet(appro_var + diag_matrix.type_as(appro_var) * _POS_ALPHA) # mean over N samples. sum over classes. rmi_per_class = rmi_now.view([-1, self.num_classes]).mean(dim=0).float() #is_half = False #if is_half: # rmi_per_class = torch.div(rmi_per_class, float(self.half_d / 2.0)) #else: rmi_per_class = torch.div(rmi_per_class, float(self.half_d)) rmi_loss = torch.sum(rmi_per_class) if _IS_SUM else torch.mean(rmi_per_class) return rmi_loss if __name__ == '__main__': from losses import loss_factory import parser_params import argparse _PSP_AUX_WEIGHT=0.4 print("start") parser = argparse.ArgumentParser(description="PyTorch Segmentation Model Training") args = parser_params.add_parser_params(parser) print("args") criterion = loss_factory.criterion_choose(2, # 是二分类问题,但是模型中是1,这个loss中是包含背景类在内的 weight=None, loss_type=2, ignore_index=255, reduction='mean', max_iter=31500, args=args) print("cri") #当报错说100 n 和 20 n不匹配的时候,不一定要把input的维数降下来,可以把target的补起来,如:4 2 5 5 和 4 5 5 loss = criterion(torch.rand(4,2,5,5), torch.rand(4,5,5)) + _PSP_AUX_WEIGHT * F.cross_entropy( input=torch.rand(4,2,5,5).squeeze(dim=1), # squeeze和unsqueeze去除和添加的都是维度为1的部分 target=torch.rand(4,5,5).long(), ignore_index=0, reduction='mean') print("loss",loss)
如果想要保留原来改动过的文件,可以把当前文件复制到原文件所在的文件夹下,并让 criterion_choose 在被调用时返回该文件夹下这个新文件中的实现。
感觉这种单元测试在已经调试清楚调用关系、但错误始终没有解决的情况下,比较能提升效率;对我来讲,跨文件跳转式的调试带来的状态切换非常低效。