使用PyTorch+functorch计算并可视化NTK矩阵

2022年3月,PyTorch发布了PyTorch1.11和functorch。

functorch灵感来自于Google JAX,旨在提供vmap和autodiff转换配合PyTorch使用。

本文将演示如何使用PyTorch和functorch计算并可视化NTK

1. 环境配置

# first install PyTorch 1.11
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch

pip install functorch

2. Setup

2.1 搭建网络

import torch
import torch.nn as nn
from functorch import make_functional, vmap, jacrev
import numpy as np
from matplotlib import pyplot as plt
device = 'cuda'

class NN(nn.Module):
    def __init__(self, layer_sizes):
        super(NN, self).__init__()
        self.linears = nn.ModuleList()
        for i in range(len(layer_sizes) - 1):
            m = nn.Linear(layer_sizes[i], layer_sizes[i + 1])
            self.linears.append(m)
        
    def forward(self, x):
        for linear in self.linears[:-1]:
            x = torch.tanh(linear(x))
        x = self.linears[-1](x)
        return x 

2.2 生成一些数据

x = torch.linspace(0, 1, 100).unsqueeze(-1).to(device)

3. 创建模型的函数版本

为了计算NTK,我们需要一个函数来接受模型的参数和单个输入(而非一批输入),并返回单个输出。可以使用functorch的make_functional完成这一步。

layer_sizes = [1] + [100] * 3 + [1]
net = NN(layer_sizes).to(device)

fnet, params = make_functional(net)
# if your net has buffers
# fnet, params, buffers = make_functional_with_buffers(net)

生成一个在单个数据点上评估模型的函数

def fnet_single(params, x):
    return fnet(params, x.unsqueeze(0)).squeeze(0)

4. 计算NTK

def empirical_ntk(fnet_single, params, x1, x2, compute='trace'):
    # Compute J(x1)
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac1 = [j.flatten(2) for j in jac1]
    
    # Compute J(x2)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac2 = [j.flatten(2) for j in jac2]
    
    # Compute J(x1) @ J(x2).T
    einsum_expr = None
    if compute == 'full':
        einsum_expr = 'Naf,Mbf->NMab'
    elif compute == 'trace':
        einsum_expr = 'Naf,Maf->NM'
    elif compute == 'diagonal':
        einsum_expr = 'Naf,Maf->NMa'
    else:
        assert False
        
    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
    result = result.sum(0)
    return result

ntk_result = empirical_ntk(fnet_single, params, x, x, 'trace')
print(ntk_result.shape)

torch.Size([100, 100])

5. 计算NTK矩阵特征值和特征向量

# Compute eigenvalues
lambda_K, eigvec_K = np.linalg.eig(ntk_result.detach().cpu().numpy())

# Sort in descresing order
lambda_K = np.sort(np.real(lambda_K))[::-1]

6. 可视化

# Visualize the eigenvectors of the NTK
fig, axs= plt.subplots(2, 3, figsize=(12, 6))
X = np.linspace(0, 1, len(x))
axs[0, 0].plot(X, np.real(eigvec_K[:,0]))
axs[0, 1].plot(X, np.real(eigvec_K[:,1]))
axs[0, 2].plot(X, np.real(eigvec_K[:,2]))
axs[1, 0].plot(X, np.real(eigvec_K[:,3]))
axs[1, 1].plot(X, np.real(eigvec_K[:,4]))
axs[1, 2].plot(X, np.real(eigvec_K[:,5]))
plt.show()

# Visualize the eigenvalues of the NTK
fig, ax = plt.subplots(figsize=(6, 5))
ax.plot(lambda_K)
plt.xscale('log')
plt.yscale('log')
ax.set_xlabel('index')
ax.set_ylabel(r'$\lambda$') 
plt.show()

欢迎交流讨论:kjzxcsq@qq.com

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
PyTorch是一种广泛应用于机器学习和深度学习的开源框架,它提供了丰富的工具和算法来构建和训练神经网络模型。而Milvus是一个高效的向量相似度搜索引擎,主要用于大规模向量数据的存储、管理和检索。 PyTorch与Milvus可以相互配合使用,以实现更加高效的机器学习和深度学习任务。PyTorch提供了强大的计算能力和灵活性,可以用于训练各种类型的神经网络模型。而Milvus则提供了快速的向量相似度搜索功能,可以将训练好的模型转化为向量表示,并高效地进行相似度匹配。 在使用PyTorch和Milvus的组合时,首先我们可以使用PyTorch来训练神经网络模型,并从中提取特征向量。然后,将这些特征向量通过Milvus进行存储和索引。Milvus提供了高效的索引结构和查询算法,可以快速地搜索和检索与查询向量最相似的向量,从而实现向量的相似度匹配和搜索。 使用PyTorch和Milvus的组合,可以在机器学习和深度学习任务中提高效率和准确性。通过PyTorch进行模型训练和特征提取,再通过Milvus进行向量索引和相似度搜索,可以在大规模数据集上快速地找到与查询向量最相似的向量,从而实现更加高效和灵活的机器学习和深度学习应用。 总而言之,PyTorch和Milvus是两个强大的工具,它们的组合可以在机器学习和深度学习任务中发挥协同作用,提供高效的模型训练和向量相似度搜索能力。这对于处理大规模数据和提高机器学习任务效果具有重要意义。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值