# -*- coding: utf-8 -*-
"""
Author: Feng
Function: read and analyze caffemodel params
Data: 2018/04/25
"""
import numpy as np
import sys
import os
np.set_printoptions(threshold='nan') # 全部打印输出,不要出现省略号
caffe_python_dir = "./pycaffe"
sys.path.append(caffe_python_dir)
if not os.path.exists(caffe_python_dir):
print("python caffe not found")
exit(-1)
import caffe
TRAIN_TEST_FILE = "./examples/mnist/lenet_train_test.prototxt" # 网络协议文件
MODEL_FILE = "./examples/mnist/lenet_iter_10000.caffemodel" # 训练好的模型文件
if not os.path.exists(TRAIN_TEST_FILE):
print("TRAIN_TEST_FILE not found")
exit(-1)
if not os.path.exists(MODEL_FILE):
print("MODEL_FILE not found")
exit(-1)
net = caffe.Net(TRAIN_TEST_FILE, MODEL_FILE, caffe.TEST) # 加载网络
print(type(net))
# 遍历层获取对应层的名称/权重参数/偏置参数
for param_name in net.params.keys():
# 名称
print(param_name)
# 权重参数
weight = net.params[param_name][0].data
print(type(weight))
print(weight.shape)
print(weight)
#weight.shape = (-1, 1) # 转为单列列表
#print(weight)
# 偏置参数
bias = net.params[param_name][1].data
print(type(bias))
print(bias.shape)
print(bias)
#exit(-1)
参考: http://www.voidcn.com/article/p-vnfdmfrn-mt.html
【辅助脚本】caffemodel参数的读取与分析
最新推荐文章于 2019-04-09 17:34:55 发布