tensor基本运算
考虑如下代码:
mask_score_f = mask_score[:, 0, :, :].unsqueeze(1) #tensor:(N,2,W,H) 水平集函数 \phi
mask_score_b = mask_score[:, 1, :, :].unsqueeze(1) #tensor:(N,C,W,H) 待分割图像
interior_ = torch.sum(mask_score_f * lst_target, (2, 3)) / torch.sum(mask_score_f, (2, 3)).clamp(min=0.00001) # 内部常值 c_1
exterior_ = torch.sum(mask_score_b * lst_target, (2, 3)) / torch.sum(mask_score_b, (2, 3)).clamp(min=0.00001) # 外部常值 c_2
interior_region_level = torch.pow(lst_target - interior_.unsqueeze(-1).unsqueeze(-1), 2)
exterior_region_level = torch.pow(lst_target - exterior_.unsqueeze(-1).unsqueeze(-1), 2)
region_level_loss = interior_region_level*mask_score_f + exterior_region_level*mask_score_b
level_set_loss = torch.sum(region_level_loss, (1, 2, 3))/lst_target.shape[1] # 通道平均;
初次见到它很不清楚含义,经过一下午print各种中间量才知道一些含义。不妨设
mask_score = torch.ones(3,2,4,5)
lst_target = torch.ones(3,3,4,5)
(必须创建tensor才能用unsqueeze函数,否则建立数组后要先转成tensor。)
第一行: mask_score[:, 0, :, :]是取第1维的第0个元素,现在tensor大小为
3
∗
4
∗
5
3*4*5
3∗4∗5;unsqueeze(1)在第1维前加一个维度,现在tensor大小为
3
∗
1
∗
4
∗
5
3*1*4*5
3∗1∗4∗5。
第二行: 同上,得tensor大小为
3
∗
1
∗
4
∗
5
3*1*4*5
3∗1∗4∗5。
第三行: mask_score_f大小为
3
∗
1
∗
4
∗
5
3*1*4*5
3∗1∗4∗5,而lst_target大小为
3
∗
3
∗
4
∗
5
3*3*4*5
3∗3∗4∗5,由于torch的broadcast机制,mask_score_f可以和lst_target相乘,相当于mask_score_f和每一个第一维的let_target都逐元素乘了一遍,进行了3遍。 相乘的结果是大小为
3
∗
3
∗
4
∗
5
3*3*4*5
3∗3∗4∗5的张量,然后得interior_是大小为
3
∗
3
3*3
3∗3的全1张量。clamp函数在这里是保证最小值的,即tensor里小于0.00001的值全改成0.00001。注意:若mask_score_f大小为
3
∗
2
∗
4
∗
5
3*2*4*5
3∗2∗4∗5,即便是仍用2,3维相乘再相加仍会报错(想想也是合理的,如果是
3
∗
1
∗
4
∗
5
3*1*4*5
3∗1∗4∗5的话可以把第1维复制三份)。
第四行: 同上,得大小为
3
∗
3
3*3
3∗3的全1张量exterior_。
第五行: lst_target大小为
3
∗
3
∗
4
∗
5
3*3*4*5
3∗3∗4∗5,interior_大小为
3
∗
3
3*3
3∗3。unsqueeze(-1)是在最后一维再拓展一维,作用两次变成
3
∗
3
∗
1
∗
1
3*3*1*1
3∗3∗1∗1大小。张量大小不一样不影响,只不过是最后两维分别复制四份和五份再做减法。再例如:a = torch.ones(3,3,2,2) ; b = torch.ones(3,3,1,1) * 2 ; c = b - a。则c为大小为
3
∗
3
∗
2
∗
2
3*3*2*2
3∗3∗2∗2的全1张量。pow(a,2)是求a的平方操作(逐分量求,跟矩阵乘法什么的无关)。
第六行: 同上,得大小为
3
∗
3
∗
4
∗
5
3*3*4*5
3∗3∗4∗5的张量。
第七行: 张量interior_region_level大小为
3
∗
3
∗
4
∗
5
3*3*4*5
3∗3∗4∗5,mask_score_f大小为
3
∗
1
∗
4
∗
5
3*1*4*5
3∗1∗4∗5,mask_score_f第1维复制三份,逐分量相乘即可。再例如:a = torch.ones(3,3,2,2) * 2 ; b = torch.ones(3,3,1,1) * 3 ; h= a * b。则h为大小为
3
∗
3
∗
2
∗
2
3*3*2*2
3∗3∗2∗2的全6张量。后面同理。
第八行: 现在region_level_loss大小为
3
∗
3
∗
4
∗
5
3*3*4*5
3∗3∗4∗5,把1-3维的数字都加起来(即后面三维,因为最前面是第0维)。lst_target.shape[1]表示lst_target第1维大小,它等于3(即
3
∗
3
∗
4
∗
5
3*3*4*5
3∗3∗4∗5的第二个3)。这个3代表通道数,所以意义是求通道平均。level_set_loss是一维张量,它仅有的一个维度(即第0维)有3个元素,这个3代表数据量大小N。理解一下的话就是level_set_loss代表每一批数据每一个通道平均水平集loss,也就是把空间上(x,y)逐点的loss相加。
注:如果是矩阵乘法,使用torch.matmul。
总结一下的话就是:如果一个维度只有1个元素,那么它做运算时可以复制,这点比较好,这样重点关注tensor的维度就可以了。
这八行代码结合原文,上午看了一小时看似明白了,实际完全没看明白,下午还是决定一点一点试,因为生成tensor数据比较方便,才算把代码的含义搞清楚。编程真不是一件容易的事情,自己水平有限,还是应该笨鸟先飞,多看多写。