pytorch cross-entropy 用于图像分割(多类别)

1、predict_result  and mask

#我们假设每个像素点可能属于3个类别之一:0 or 1 or 2

#注意mask的shape=(1,2,2),我们要指定batch_size:1,但是不用指定channel,不用对mask做one_hot处理
#四个像素:four piexls
mask = torch.tensor([[[1,2],[2,0]]])

#网络的输出shape=(1,3,2,2),batch_size=1,channel=3

pre = torch.tensor(
        [[
          [[ 1.4146,  0.1446],
          [ 1.8990, -1.1429]],

         [[ 0.9828, -1.8893],
          [-1.5256, -1.8724]],

         [[-1.2521,  1.8987],
          [-0.2531, -0.1878]]
         ]])

 

2、调用cross-entropy

cross_entropy(pre,mask)
#output of above function
tensor(1.2123)

 

3、验证pytorch算的对不对(调包侠可以不用看*——*)

                                                                                                 2

每个像素的cross-entropy    :p1 = crossentropy for pixel 1=-\sum yi*log(pi)     

                                                                                                i=0

result = \frac{p1+p2+p3+p4}{4}

#1、对输出计算softmax
softmax = F.softmax(pre,1)
print(softmax)

---------output:
tensor([[[[0.5818, 0.1447],
          [0.8705, 0.2450]],

         [[0.3778, 0.0189],
          [0.0283, 0.1181]],

         [[0.0404, 0.8363],
          [0.1012, 0.6368]]]])

#三个通道的对应值相加都为1,例如0.581+0.3778+0.0404

 

#2、对softmax每个值求自然对数


log_softmax=torch.log(softmax)


----------------output:
tensor([[[[-0.5416, -1.9328],
          [-0.1387, -1.4064]],

         [[-0.9734, -3.9667],
          [-3.5633, -2.1359]],

         [[-3.2083, -0.1787],
          [-2.2908, -0.4513]]]])

 

#3、mask为[[[1,2],[2,0]]],故我们以mask为index从log_softmax中挑选出这四个值,因为对mask进行softmax编码后只有index对应的值为1,其余都为0,和log_softmax中的值相乘都为0不必考虑,分析下cross_entropy的公式就明白

output=0.9734+0.1787+2.2908+1.4064
average_output = output/4
               = 1.2123249999999999
               = F.cross_entropy(pre,mask)

 

 

ps:用keras时我完全没想这么多,毕竟太好用了。。。。。。

 

 

 

 

 

 

 

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值