FCN源码解读之score.py

转自:https://blog.csdn.net/qq_21368481/article/details/80424754

score.py是FCN中用于测试测试集/验证集的,并输出相应的像素准确度、平均准确度、mean IU和频率加权交并比(frequency weighted IU)四个指标的python文件。

score.py的源码如下:

[python] view plain copy
print ?
  1. from future import division  
  2. import caffe  
  3. import numpy as np  
  4. import os  
  5. import sys  
  6. from datetime import datetime  
  7. from PIL import Image  
  8.   
  9. def fast_hist(a, b, n):  
  10.     k = (a >= 0) & (a < n)  
  11.     return np.bincount(n  a[k].astype(int) + b[k], minlength=n*2).reshape(n, n)  
  12.   
  13. def compute_hist(net, save_dir, dataset, layer=‘score’, gt=‘label’):  
  14.     n_cl = net.blobs[layer].channels  
  15.     if save_dir:  
  16.         os.mkdir(save_dir)  
  17.     hist = np.zeros((n_cl, n_cl))  
  18.     loss = 0  
  19.     for idx in dataset:  
  20.         net.forward()  
  21.         hist += fast_hist(net.blobs[gt].data[00].flatten(),  
  22.                                 net.blobs[layer].data[0].argmax(0).flatten(),  
  23.                                 n_cl)  
  24.   
  25.         if save_dir:  
  26.             im = Image.fromarray(net.blobs[layer].data[0].argmax(0).astype(np.uint8), mode=‘P’)  
  27.             im.save(os.path.join(save_dir, idx + ’.png’))  
  28.         # compute the loss as well  
  29.         loss += net.blobs[’loss’].data.flat[0]  
  30.     return hist, loss / len(dataset)  
  31.   
  32. def seg_tests(solver, save_format, dataset, layer=‘score’, gt=‘label’):  
  33.     print ‘>>>’, datetime.now(), ‘Begin seg tests’  
  34.     solver.test_nets[0].share_with(solver.net)  
  35.     do_seg_tests(solver.test_nets[0], solver.iter, save_format, dataset, layer, gt)  
  36.   
  37. def do_seg_tests(net, iter, save_format, dataset, layer=‘score’, gt=‘label’):  
  38.     n_cl = net.blobs[layer].channels  
  39.     if save_format:  
  40.         save_format = save_format.format(iter)  
  41.     hist, loss = compute_hist(net, save_format, dataset, layer, gt)  
  42.     # mean loss  
  43.     print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘loss’, loss  
  44.     # overall accuracy  
  45.     acc = np.diag(hist).sum() / hist.sum()  
  46.     print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘overall accuracy’, acc  
  47.     # per-class accuracy  
  48.     acc = np.diag(hist) / hist.sum(1)  
  49.     print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘mean accuracy’, np.nanmean(acc)  
  50.     # per-class IU  
  51.     iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))  
  52.     print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘mean IU’, np.nanmean(iu)  
  53.     freq = hist.sum(1) / hist.sum()  
  54.     print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘fwavacc’, \  
  55.             (freq[freq > 0] * iu[freq > 0]).sum()  
  56.     return hist  
from future import division 
import caffe
import numpy as np
import os
import sys
from datetime import datetime
from PIL import Image

def fast_hist(a, b, n):
k = (a >= 0) & (a < n)
return np.bincount(n * a[k].astype(int) + b[k], minlength=n**2).reshape(n, n)

def compute_hist(net, save_dir, dataset, layer='score', gt='label'):
n_cl = net.blobs[layer].channels
if save_dir:
os.mkdir(save_dir)
hist = np.zeros((n_cl, n_cl))
loss = 0
for idx in dataset:
net.forward()
hist += fast_hist(net.blobs[gt].data[0, 0].flatten(),
net.blobs[layer].data[0].argmax(0).flatten(),
n_cl)

    if save_dir:
        im = Image.fromarray(net.blobs[layer].data[0].argmax(0).astype(np.uint8), mode='P')
        im.save(os.path.join(save_dir, idx + '.png'))
    # compute the loss as well
    loss += net.blobs['loss'].data.flat[0]
return hist, loss / len(dataset)

def seg_tests(solver, save_format, dataset, layer=’score’, gt=’label’):
print ‘>>>’, datetime.now(), ‘Begin seg tests’
solver.test_nets[0].share_with(solver.net)
do_seg_tests(solver.test_nets[0], solver.iter, save_format, dataset, layer, gt)

def do_seg_tests(net, iter, save_format, dataset, layer=’score’, gt=’label’):
n_cl = net.blobs[layer].channels
if save_format:
save_format = save_format.format(iter)
hist, loss = compute_hist(net, save_format, dataset, layer, gt)
# mean loss
print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘loss’, loss
# overall accuracy
acc = np.diag(hist).sum() / hist.sum()
print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘overall accuracy’, acc
# per-class accuracy
acc = np.diag(hist) / hist.sum(1)
print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘mean accuracy’, np.nanmean(acc)
# per-class IU
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘mean IU’, np.nanmean(iu)
freq = hist.sum(1) / hist.sum()
print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘fwavacc’, \
(freq[freq > 0] * iu[freq > 0]).sum()
return hist

详细解读如下:

(1)fast_hist()函数

[python] view plain copy
print ?
  1. ”’ 
  2. 产生n×n的分类统计表 
  3. 参数a:标签图(转换为一行输入),即真实的标签 
  4. 参数b:score层输出的预测图(转换为一行输入),即预测的标签 
  5. 参数n:类别数 
  6. ”’  
  7. def fast_hist(a, b, n):  
  8.     #k为掩膜(去除了255这些点(即标签图中的白色的轮廓),其中的a>=0是为了防止bincount()函数出错)  
  9.     k = (a >= 0) & (a < n)   
  10.     #bincount()函数用于统计数组内每个非负整数的个数  
  11.     #详见https://docs.scipy.org/doc/numpy/reference/generated/numpy.bincount.html  
  12.     return np.bincount(n  a[k].astype(int) + b[k], minlength=n*2).reshape(n, n)  
''' 
产生n×n的分类统计表
参数a:标签图(转换为一行输入),即真实的标签
参数b:score层输出的预测图(转换为一行输入),即预测的标签
参数n:类别数
'''
def fast_hist(a, b, n):
#k为掩膜(去除了255这些点(即标签图中的白色的轮廓),其中的a>=0是为了防止bincount()函数出错)
k = (a >= 0) & (a < n)
#bincount()函数用于统计数组内每个非负整数的个数
#详见https://docs.scipy.org/doc/numpy/reference/generated/numpy.bincount.html
return np.bincount(n * a[k].astype(int) + b[k], minlength=n**2).reshape(n, n)

此函数用于产生n*n的分类统计表,还不理解的可以看如下分析:

假如输入的标签图a是3*3的,如下左图,图中的数字表示该像素点的归属,即每个像素点所属的类别(其中n=3,即共有三种类别);预测标签图b的大小和a相同,如右图所示(图中的数字也代表每个像素点的类别归属)。

                   

直观上看,b中预测的标签有两个像素点预测出错,即

http://www.w3.org/1998/Math/MathML&quot; display="block">b01,b20” role=”presentation”>b01,b20b01,b20

源码中的这句语句是精华:np.bincount(n * a[k].astype(int) + b[k], minlength=n**2)

其作用是产生一行n*n个元素的向量,向量中的每个元素存储统计结果,假如该向量为d,则其中的d(i*n+j)表示预测结果为类别j,实际标签为类别i的所有像素点的数目。

将上述的a、b和n输入fast_hist(a, b, n),所产生的d为:d=(3,0,0,0,2,1,0,1,2),其中的d(1*3+1)=d(4)表示预测类别为1,实际标签也为1的所有像素点数目为2。

通过reshape(n, n)将向量d转换为3*3的矩阵,其结果如下表(该矩阵即为下表中的绿色部分):


其中绿色的3*3表格统计的含义,拿数字3所在的这一格为例,即预测标签中被预测为类别0的且其真实标签也为0的所有像素点数目之和。

上述表格有几点需要注意的是(这三条是用于计算一开始所讲的四个指标的基础):

①绿色表格中对角线元素上的数字即为该类别预测正确的像素点数目,非对角线元素都是预测错误的,拿最后一行的数字1为例,其含义即为有一个原本应属于类别2的像素点被错误地预测为类别1;

②绿色表格的每一行求和得到的数字的含义是真实标签中属于某一类别的所有像素点数目,拿第一行为例,3+0+0=3,即真实属于类别0的像素点一共3个;

③绿色表格的每一列求和得到的数字的含义是预测为某一类别的所有像素点数目,拿第二列为例,0+2+1=3,即预测为类别1的所有像素点共有3个。

(2)compute_hist()函数

调用fast_hist()函数,遍历测试集/验证集中的所有样本,统计总的分类统计表和loss。

[python] view plain copy
print ?
  1. def compute_hist(net, save_dir, dataset, layer=‘score’, gt=‘label’):  
  2.     n_cl = net.blobs[layer].channels   #score层的特征图数目(也即类别数,例如VOC数据集,这里就为21)  
  3.     if save_dir:  
  4.         os.mkdir(save_dir)   #创建目录  
  5.     hist = np.zeros((n_cl, n_cl)) #创建一个n_cl×n_cl大小的零矩阵,用于存储分类统计表  
  6.     loss = 0 #初始化损失  
  7.     #循环统计每一张测试/验证图片的预测结果,并求和保存在hist中  
  8.     for idx in dataset:  
  9.         net.forward()  
  10.         #net.blobs[gt].data[0, 0]存放着H×W大小的标签图数据  
  11.         #net.blobs[layer].data[0]存放着C×H×W大小的预测图数据(共C张),并通过argmax()获得最终的  
  12.         #H×W大小的预测图数据(数据范围和标签图一致)(argmax本身得到的是最大值处的数组索引号)  
  13.         #flatten()函数平铺整个数组为一行  
  14.         hist += fast_hist(net.blobs[gt].data[00].flatten(),  
  15.                                 net.blobs[layer].data[0].argmax(0).flatten(),  
  16.                                 n_cl)  
  17.   
  18.         if save_dir:  
  19.             #mode=’P’表示产生一张单通道的彩色图  
  20.             #(详见http://pillow.readthedocs.io/en/3.1.x/handbook/concepts.html#concept-modes  
  21.             im = Image.fromarray(net.blobs[layer].data[0].argmax(0).astype(np.uint8), mode=‘P’)  
  22.             im.save(os.path.join(save_dir, idx + ’.png’))  
  23.         # compute the loss as well  
  24.         loss += net.blobs[’loss’].data.flat[0]  #累加每张测试图的loss  
  25.     return hist, loss / len(dataset)  
def compute_hist(net, save_dir, dataset, layer='score', gt='label'): 
n_cl = net.blobs[layer].channels #score层的特征图数目(也即类别数,例如VOC数据集,这里就为21)
if save_dir:
os.mkdir(save_dir) #创建目录
hist = np.zeros((n_cl, n_cl)) #创建一个n_cl×n_cl大小的零矩阵,用于存储分类统计表
loss = 0 #初始化损失
#循环统计每一张测试/验证图片的预测结果,并求和保存在hist中
for idx in dataset:
net.forward()
#net.blobs[gt].data[0, 0]存放着H×W大小的标签图数据
#net.blobs[layer].data[0]存放着C×H×W大小的预测图数据(共C张),并通过argmax()获得最终的
#H×W大小的预测图数据(数据范围和标签图一致)(argmax本身得到的是最大值处的数组索引号)
#flatten()函数平铺整个数组为一行
hist += fast_hist(net.blobs[gt].data[0, 0].flatten(),
net.blobs[layer].data[0].argmax(0).flatten(),
n_cl)
    if save_dir:
        #mode='P'表示产生一张单通道的彩色图
        #(详见http://pillow.readthedocs.io/en/3.1.x/handbook/concepts.html#concept-modes)
        im = Image.fromarray(net.blobs[layer].data[0].argmax(0).astype(np.uint8), mode='P')
        im.save(os.path.join(save_dir, idx + '.png'))
    # compute the loss as well
    loss += net.blobs['loss'].data.flat[0]  #累加每张测试图的loss
return hist, loss / len(dataset)</pre><p>从最后一句&nbsp;return hist, loss / len(dataset)也可以看出,返回的loss是总的loss除以总的样本数。</p><p>还有比较惊喜的一点是这里的im = Image.fromarray(net.blobs[layer].data[0].argmax(0).astype(np.uint8), mode='P')很有启发,这句语句采用mode='P'的形式生成一张单通道的彩色图(实际就是标签图),通过自行编写的代码验证,这个应该就是VOC数据集中产生标签图的一种方法。</p><p>代码如下:</p><div class="dp-highlighter bg_python"><div class="bar"><div class="tools"><b>[python]</b> <a href="#" class="ViewSource" title="view plain" onclick="dp.sh.Toolbar.Command('ViewSource',this);return false;">view plain</a><span data-mod="popu_168"> <a href="#" class="CopyToClipboard" title="copy" onclick="dp.sh.Toolbar.Command('CopyToClipboard',this);return false;">copy</a><div style="position: absolute; left: 307px; top: 3904px; width: 16px; height: 16px; z-index: 99;"><embed id="ZeroClipboardMovie_4" src="https://csdnimg.cn/public/highlighter/ZeroClipboard.swf" loop="false" menu="false" quality="best" bgcolor="#ffffff" name="ZeroClipboardMovie_4" allowscriptaccess="always" allowfullscreen="false" type="application/x-shockwave-flash" pluginspage="http://www.macromedia.com/go/getflashplayer" flashvars="id=4&amp;width=16&amp;height=16" wmode="transparent" width="16" height="16" align="middle"></div><div style="position: absolute; left: 307px; top: 3946px; width: 16px; height: 16px; z-index: 99;"><embed id="ZeroClipboardMovie_10" src="https://csdnimg.cn/public/highlighter/ZeroClipboard.swf" loop="false" menu="false" quality="best" bgcolor="#ffffff" name="ZeroClipboardMovie_10" allowscriptaccess="always" allowfullscreen="false" type="application/x-shockwave-flash" pluginspage="http://www.macromedia.com/go/getflashplayer" flashvars="id=10&amp;width=16&amp;height=16" wmode="transparent" width="16" height="16" align="middle"></div></span><span data-mod="popu_169"> <a href="#" class="PrintSource" title="print" onclick="dp.sh.Toolbar.Command('PrintSource',this);return false;">print</a></span><a href="#" class="About" title="?" onclick="dp.sh.Toolbar.Command('About',this);return false;">?</a></div></div><ol start="1" class="dp-py"><li class="alt"><span><span class="keyword">import</span><span>&nbsp;numpy&nbsp;as&nbsp;np&nbsp;&nbsp;</span></span></li><li class=""><span><span class="keyword">from</span><span>&nbsp;PIL&nbsp;</span><span class="keyword">import</span><span>&nbsp;Image&nbsp;&nbsp;</span></span></li><li class="alt"><span>im&nbsp;=&nbsp;Image.open(<span class="string">'C:/Users/Zheng&nbsp;Chen/Desktop/2007_000392.png'</span><span>)&nbsp;&nbsp;</span></span></li><li class=""><span>in_&nbsp;=&nbsp;np.array(im,&nbsp;dtype=np.uint8)&nbsp;&nbsp;</span></li><li class="alt"><span><span class="keyword">print</span><span>&nbsp;in_&nbsp;&nbsp;</span></span></li><li class=""><span>out&nbsp;=&nbsp;Image.fromarray(in_,&nbsp;mode=<span class="string">'P'</span><span>)&nbsp;&nbsp;</span></span></li><li class="alt"><span>im.save(<span class="string">'C:/Users/Zheng&nbsp;Chen/Desktop/2008.png'</span><span>)&nbsp;&nbsp;</span></span></li></ol></div><pre class="python" name="code" style="display: none;">import numpy as np

from PIL import Image
im = Image.open(‘C:/Users/Zheng Chen/Desktop/2007_000392.png’)
in_ = np.array(im, dtype=np.uint8)
print in_
out = Image.fromarray(in_, mode=’P’)
im.save(‘C:/Users/Zheng Chen/Desktop/2008.png’)

结果如下图(左图为原标签图,右图为上部分代码生成的标签图,可以说是一模一样)

      

              

(3)seg_tests()函数

此函数是score.py文件的入口,即调用此函数便能完成整个测试,并输出四个指标值,此函数的调用可以参见FCN中的solve.py文件中的最后一部分。

[python] view plain copy
print ?
  1. #调用此函数,完成分割测试  
  2. def seg_tests(solver, save_format, dataset, layer=‘score’, gt=‘label’):  
  3.     print ‘>>>’, datetime.now(), ‘Begin seg tests’  
  4.     solver.test_nets[0].share_with(solver.net)  
  5.     do_seg_tests(solver.test_nets[0], solver.iter, save_format, dataset, layer, gt)  
#调用此函数,完成分割测试 
def seg_tests(solver, save_format, dataset, layer='score', gt='label'):
print '>>>', datetime.now(), 'Begin seg tests'
solver.test_nets[0].share_with(solver.net)
do_seg_tests(solver.test_nets[0], solver.iter, save_format, dataset, layer, gt)

(4)do_seg_tests()函数

此函数在compute_hist的基础上,计算出一开始就提出的四个评价指标。

[python] view plain copy
print ?
  1. def do_seg_tests(net, iter, save_format, dataset, layer=‘score’, gt=‘label’):  
  2.     n_cl = net.blobs[layer].channels  #score层的特征图数目(也即类别数,例如VOC数据集,这里就为21)  
  3.     if save_format:  
  4.         save_format = save_format.format(iter)  
  5.     #调用compute_hist统计分割结果  
  6.     hist, loss = compute_hist(net, save_format, dataset, layer, gt)   
  7.     # mean loss 平均loss(即所有测试集总体误差/测试集数目)  
  8.     print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘loss’, loss  
  9.     # overall accuracy 总体准确度(hist对角线为正确分类结果,其余均为错误分类结果)  
  10.     #np.diag()详见https://docs.scipy.org/doc/numpy/reference/generated/numpy.diag.html  
  11.     acc = np.diag(hist).sum() / hist.sum()  
  12.     print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘overall accuracy’, acc  
  13.     # per-class accuracy 每一类的准确度  
  14.     acc = np.diag(hist) / hist.sum(1#/为对应位相除;sum(1)为按行求和  
  15.     #输出平均准确度(注意平均准确度与总体准确度但区别)  
  16.     #np.nanmean()详见https://docs.scipy.org/doc/numpy/reference/generated/numpy.nanmean.html  
  17.     print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘mean accuracy’, np.nanmean(acc)   
  18.     # per-class IU 每一类的交并比  
  19.     #sum(0)为按列求和  
  20.     iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))  
  21.     #输出平均交并比  
  22.     print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘mean IU’, np.nanmean(iu)  
  23.     freq = hist.sum(1) / hist.sum()  
  24.     #输出频率加权交并比(frequency weighted IU)  
  25.     print ‘>>>’, datetime.now(), ‘Iteration’, iter, ‘fwavacc’, \  
  26.             (freq[freq > 0] * iu[freq > 0]).sum()  
  27.     return hist  
def do_seg_tests(net, iter, save_format, dataset, layer='score', gt='label'): 
n_cl = net.blobs[layer].channels #score层的特征图数目(也即类别数,例如VOC数据集,这里就为21)
if save_format:
save_format = save_format.format(iter)
#调用compute_hist统计分割结果
hist, loss = compute_hist(net, save_format, dataset, layer, gt)
# mean loss 平均loss(即所有测试集总体误差/测试集数目)
print '>>>', datetime.now(), 'Iteration', iter, 'loss', loss
# overall accuracy 总体准确度(hist对角线为正确分类结果,其余均为错误分类结果)
#np.diag()详见https://docs.scipy.org/doc/numpy/reference/generated/numpy.diag.html
acc = np.diag(hist).sum() / hist.sum()
print '>>>', datetime.now(), 'Iteration', iter, 'overall accuracy', acc
# per-class accuracy 每一类的准确度
acc = np.diag(hist) / hist.sum(1) #/为对应位相除;sum(1)为按行求和
#输出平均准确度(注意平均准确度与总体准确度但区别)
#np.nanmean()详见https://docs.scipy.org/doc/numpy/reference/generated/numpy.nanmean.html
print '>>>', datetime.now(), 'Iteration', iter, 'mean accuracy', np.nanmean(acc)
# per-class IU 每一类的交并比
#sum(0)为按列求和
iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
#输出平均交并比
print '>>>', datetime.now(), 'Iteration', iter, 'mean IU', np.nanmean(iu)
freq = hist.sum(1) / hist.sum()
#输出频率加权交并比(frequency weighted IU)
print '>>>', datetime.now(), 'Iteration', iter, 'fwavacc', \
(freq[freq > 0] * iu[freq > 0]).sum()
return hist

在此,介绍一下FCN中用到的四个评价指标(对应FCN论文中的第5部分(Results部分))。

①像素准确度(对应源码解析中的overall accuracy):

http://www.w3.org/1998/Math/MathML&quot; display="block">&#x2211;inii/&#x2211;iti” role=”presentation”>inii/iti∑inii/∑iti

表示属于第i类的所有像素数目

对应do_seg_tests()函数中的源码,相信大家肯定能更好理解和掌握这四个指标的计算技巧。

其中还有一点,就是交并比IU,为所有真实属于第i类的像素点所组成的集合A与所有预测属于第i类的像素点所组成的集合B的交集和并集之比,如下图




    <div class="article-bar-bottom">
                    <div class="tags-box artic-tag-box">
        <span class="label">文章标签:</span>
                    <a class="tag-link" href="http://so.csdn.net/so/search/s.do?q=Deep Learning&amp;t=blog" target="_blank">Deep Learning                       </a>
    </div>
                    <div class="tags-box">
        <span class="label">个人分类:</span>
                    <a class="tag-link" href="https://blog.csdn.net/qq_21368481/article/category/7642210" target="_blank">FCN                       </a>
    </div>
                </div>

<!-- !empty($pre_next_article[0]) -->
  • 1
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值