参考:https://blog.csdn.net/leviopku/article/details/78510977
最近需要把resnet中的参数(权重,偏置)从网络中取出来分析,网上资料很少,怼了一下午算是成功了,记录一下。
选择的框架是TF,因为pytorch提出来的参数数据结构复杂,而TF可以直接变为numpy。
思路就是先把所有参数用tf.train.NewCheckpointReader 和 get_variable_to_shape_map()
变为字典。Key就是网络结构的名字,然后用一个循环,提取每层的参数,并reshape为行向量(因为维度变化大,所以要先求参数量),然后拼接到初始的矩阵中。
最后初始矩阵就会变为结果矩阵。
代码如下:
代码中未显示全的是tf模型ckpt文件的路径。