caffe画 feature map

import numpy as np
import matplotlib.pyplot as plt
import sys
import os
import caffe
plt.rcParams['figure.figsize'] = (10, 10)        # large images
plt.rcParams['image.interpolation'] = 'nearest'  # don't interpolate: show square pixels
plt.rcParams['image.cmap'] = 'gray'  # use grayscale output rather than a (potentially misleading) color heatmap


caffe_root = 'D:/*/caffe-master/'  # this file should be run from {caffe_root}/examples (otherwise change this line)
model_root = 'D:/RFCN/'


sys.path.insert(0, caffe_root + 'python')


caffe.set_mode_cpu()


model_def = model.prototxt'
model_weights = model_root + 'resnet50_iter_110000.caffemodel'


net = caffe.Net(model_def,model_weights,caffe.TEST)
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2,0,1))
transformer.set_raw_scale('data', 255)
transformer.set_channel_swap('data', (2,1,0))  # swap channels from RGB to BGR
print net.blobs['data'].data.shape
image = caffe.io.load_image('D:/000001.png')
transformed_image = transformer.preprocess('data', image)
plt.imshow(image)
plt.show()


net.blobs['data'].data[...] = transformed_image
output = net.forward()


def vis_square(data):
    """Take an array of shape (n, height, width) or (n, height, width, 3)
       and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)"""


    # normalize data for display
    data = (data - data.min()) / (data.max() - data.min())


    # force the number of filters to be square
    n = int(np.ceil(np.sqrt(data.shape[0])))
    padding = (((0, n ** 2 - data.shape[0]),
                (0, 1), (0, 1))  # add some space between filters
               + ((0, 0),) * (data.ndim - 3))  # don't pad the last dimension (if there is one)
    data = np.pad(data, padding, mode='constant', constant_values=1)  # pad with ones (white)


    # tile the filters into an image
    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])


    plt.imshow(data)
    plt.axis('off')
    plt.show()


filters = net.blobs['ipmdata'].data
print(filters.shape)
vis_square(filters.transpose(0, 2, 3, 1))
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值