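Reconstructing data backward through a neural network is very useful for tasks such as feature extraction and visualization. The approaches I found while researching, such as heatmaps, INN (invertible neural network) inverse models, and adversarial attacks, are all fairly complex. Here I use only the matrix inverse and SVD to reconstruct the inputs of a linear layer's neurons, which is easy to embed in existing code. I'm writing it down here for reference.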
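The linear-algebra relations behind this idea can be written out as follows. This assumes PyTorch's `nn.Linear` convention (weight W of shape out × in, applied as y = xW^T + b, with samples as rows of x). Filling the lost directions with the data mean is one possible choice for dimension-reducing layers, not the only one.

```latex
% nn.Linear with weight W (m x n) and bias b, samples as rows of x
y = x W^{\top} + b
% Square, invertible W (m = n): undo the layer exactly
x = (y - b)\,(W^{\top})^{-1} = (y - b)\,(W^{-1})^{\top}
% Non-square W: reduced SVD W = U \Sigma V^{\top}, pseudo-inverse W^{+} = V \Sigma^{-1} U^{\top}
\hat{x} = (y - b)\,(W^{+})^{\top} = (y - b)\,U \Sigma^{-1} V^{\top}
% Directions in the null space of W are lost; one option is to fill them with the data mean \bar{x}
\hat{x} \leftarrow \hat{x} + \bar{x}\,(I - W^{+} W)
% Inverse of the Tanh activation, valid for |z| < 1
\tanh^{-1}(z) = \tfrac{1}{2}\ln\frac{1+z}{1-z}
```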
"""
Reconstructing data backward through Linear neurons
"""
import torch
import matplotlib.pyplot as plt
N = 10  # number of samples
n = 3   # feature dimension
x = torch.randn(N, n)
model = torch.nn.Sequential(
    torch.nn.Linear(n, n), torch.nn.Tanh()
)
y = model(x)
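with torch.no_grad():
    z = y
    # Walk the layers from the output back to the input, undoing each one
    for step in list(model.children())[::-1]:
        if isinstance(step, torch.nn.Linear):
            z = z - step.bias
            w = step.weight  # shape: (out_features, in_features)
            w_size = w.size()
            if w_size[0] == w_size[1]:
                # Square weight: the data stays in a space of the same dimension, so the layer can be inverted directly.
                # nn.Linear computes y = x @ w.T + b, hence $x = (y-b) @ (w^{-1})^T$
                t = torch.linalg.inv(w)
                z = torch.mm(z, t.T)
            else:
                # Non-square weight: use the SVD pseudo-inverse, w = U diag(S) Vh  =>  x = (y-b) @ U @ diag(1/S) @ Vh
                U, S, Vh = torch.linalg.svd(w, full_matrices=False)
                z = z @ U @ torch.diag(1.0 / S) @ Vh
                # When w reduces the dimension, the null-space components of the input are lost;
                # fill them with the data mean (x is the network input, so this only matches the first Linear layer)
                mean = torch.mean(x, dim=0)
                z = z + mean - (mean @ Vh.T) @ Vh
        elif isinstance(step, torch.nn.Tanh):
            # Inverse of tanh: 0.5 * log((1 + z) / (1 - z)), valid for |z| < 1
            z = 0.5 * torch.log((1 + z) / (1 - z))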
print('Agreement between x and z: ', torch.dist(x, z, 2))
# Plot the reconstructed data as a heatmap
plt.imshow(z.detach().numpy(), cmap='hot', interpolation='nearest')
plt.colorbar()
plt.show()
# Plot the weight distribution of the last Linear layer visited
plt.hist(w.detach().numpy().flatten(), bins=50)
plt.show()
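The script above only exercises a square `Linear(n, n)` layer. As a hedged sketch, the helpers below build the pseudo-inverse from the SVD so the same backward walk covers widening and narrowing layers too. The names `pinv_via_svd`, `invert_linear` and `invert_sequential` are hypothetical, and the sketch assumes full-rank weights, a `Sequential` made only of `Linear` and `Tanh`, and that the mean fill is applied only to the first layer (the only one whose input is the original `x`). It uses float64 to keep `atanh` and the inversions numerically accurate.

```python
import torch


def pinv_via_svd(w: torch.Tensor) -> torch.Tensor:
    """Moore-Penrose pseudo-inverse from the reduced SVD; assumes w has full rank."""
    U, S, Vh = torch.linalg.svd(w, full_matrices=False)
    return Vh.T @ torch.diag(1.0 / S) @ U.T


def invert_linear(layer, y, x_mean=None):
    """Undo y = x @ W.T + b for an nn.Linear layer (hypothetical helper)."""
    w_pinv = pinv_via_svd(layer.weight)
    z = (y - layer.bias) @ w_pinv.T  # least-squares / minimum-norm solution
    if x_mean is not None:
        # P = W^+ W projects onto the row space of W (P is symmetric);
        # add back the part of the mean that W maps to zero
        P = w_pinv @ layer.weight
        z = z + x_mean - x_mean @ P
    return z


def invert_sequential(model, y, x_mean):
    """Invert a Sequential of Linear/Tanh layers; the mean fill is used only for the first layer."""
    layers = list(model.children())
    z = y
    for i in reversed(range(len(layers))):
        step = layers[i]
        if isinstance(step, torch.nn.Linear):
            z = invert_linear(step, z, x_mean if i == 0 else None)
        elif isinstance(step, torch.nn.Tanh):
            z = torch.atanh(z)  # equals 0.5 * log((1 + z) / (1 - z))
        else:
            raise NotImplementedError(f"no inverse for {type(step).__name__}")
    return z


torch.manual_seed(0)
N, n = 10, 3
x = torch.randn(N, n, dtype=torch.float64)

# Widening/square layers: each W has full column rank, so x can be recovered exactly
wide = torch.nn.Sequential(
    torch.nn.Linear(n, 5), torch.nn.Tanh(), torch.nn.Linear(5, 5), torch.nn.Tanh()
).double()
# Narrowing layer 3 -> 2: one input direction is lost
narrow = torch.nn.Linear(n, 2).double()

with torch.no_grad():
    x_wide = invert_sequential(wide, wide(x), x.mean(dim=0))
    y_narrow = narrow(x)
    x_narrow = invert_linear(narrow, y_narrow, x.mean(dim=0))
    print("wide   ||x - x_rec||    =", torch.dist(x, x_wide).item())
    print("narrow ||x - x_rec||    =", torch.dist(x, x_narrow).item())
    print("narrow ||f(x_rec) - y|| =", torch.dist(narrow(x_narrow), y_narrow).item())
```

For the widening model the reconstruction error should be close to zero. For the 3-to-2 layer the original `x` cannot be recovered, but pushing the reconstruction back through the layer should reproduce `y`, because the mean fill only adds components that `W` maps to zero.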