一、关键代码与效果展示
代码:
from matplotlib import pyplot as plt
def plot_weights(model):
modules = [module for module in model.modules()]
num_sub_plot = 0
for i, layer in enumerate(modules):
if hasattr(layer, 'weight'):
plt.subplot(221+num_sub_plot)
w = layer.weight.data
w_one_dim = w.cpu().numpy().flatten()
plt.hist(w_one_dim, bins=50)
num_sub_plot += 1
plt.show()
plot_weights(model)
plot_weights(quant_model)
效果:
可以看到上图所示代码可视化了未量化前的网络权重参数和量化后的网络权重参数,总体可分为两步进行,第一步是读取网络权重参数,第二步是可视化权重参数。
二、读取网络权重参数
代码:
modules = [module for module in model.modules()]
print("="*10)
print("网络结构如下所示:")
for i, layer in enumerate(modules):
print(i,layer)
print("="*10)
print("其中具有权重参数的层为:")
for i, layer in enumerate(modules):
if hasattr(layer, 'weight'):
print(i,layer)
#print(layer.weight)
#print(layer.weight.data)
思路是借用 nn.model.modules() 函数获取网络各层,然后再迭代判断哪些层具有 weight
属性,即存在权重参数。
结果:
可以看到只有卷积层和线性层具有权重参数,该网络一共是有四个层具有权重参数。
三、可视化权重参数
所以要绘制四个子图,用 matplotlib.pyplot.hist
函数即可完成直方图的绘制,思路就是迭代读取网络各层权重参数,然后每读取一层具有权重参数的网络就绘制对应的频数直方图。