[KAN学习日记](3)使用KAN进行持续学习

教程地址:https://github.com/KindXiaoming/pykan/blob/master/README.md
本系列文章用于记录自己学习过程

在本教程中,我们的目标是通过样本学习一个一维函数。该一维函数包含5个高斯峰。不同于将所有样本一次性呈现给神经网络,我们将进行五个学习阶段。在每个阶段中,仅向KAN(核逼近网络)呈现一个峰值周围的样本。我们发现,由于样条的局部性,KAN能够进行持续学习。
(1)环境设置和数据准备

from kan import *
import numpy as np
import torch
import matplotlib.pyplot as plt
datasets = []

n_peak = 5
n_num_per_peak = 100
n_sample = n_peak * n_num_per_peak

x_grid = torch.linspace(-1, 1, steps=n_sample)

x_centers = 2/n_peak * (np.arange(n_peak) - n_peak/2+0.5)

x_sample = torch.stack([torch.linspace(-1/n_peak, 1/n_peak, steps=n_num_per_peak) + center for center in x_centers]).reshape(-1,)

y = 0.
for center in x_centers:
    y += torch.exp(-(x_grid - center) ** 2 * 300)
    
y_sample = 0.
for center in x_centers:
    y_sample += torch.exp(-(x_sample - center) ** 2 * 300)


通过绘制这些数据,我们可以直观地看到高斯峰及其样本分布。

plt.plot(x_grid.detach().numpy(), y.detach().numpy())
plt.scatter(x_sample.detach().numpy(), y_sample.detach().numpy())
plt.show()

请添加图片描述
(2)分阶段呈现不同的峰值给KAN
我们分五个阶段,将每个峰值周围的样本逐步呈现给KAN。

plt.subplots(1, 5, figsize=(15, 2))
plt.subplots_adjust(wspace=0, hspace=0)

for i in range(1, 6):
    plt.subplot(1, 5, i)
    group_id = i - 1
    plt.plot(x_grid.detach().numpy(), y.detach().numpy(), color='black', alpha=0.1)
    plt.scatter(x_sample[group_id * n_num_per_peak:(group_id + 1) * n_num_per_peak].detach().numpy(), y_sample[group_id * n_num_per_peak:(group_id + 1) * n_num_per_peak].detach().numpy(), color="black", s=2)
    plt.xlim(-1, 1)
    plt.ylim(-1, 2)
plt.show()

请添加图片描述
(3)训练KAN模型
我们为每个阶段分别训练KAN模型。

ys = []

model = KAN(width=[1, 1], grid=200, k=3, noise_scale=0.1, bias_trainable=False, sp_trainable=False, sb_trainable=False)

for group_id in range(n_peak):
    dataset = {}
    dataset['train_input'] = x_sample[group_id * n_num_per_peak:(group_id + 1) * n_num_per_peak][:, None]
    dataset['train_label'] = y_sample[group_id * n_num_per_peak:(group_id + 1) * n_num_per_peak][:, None]
    dataset['test_input'] = x_sample[group_id * n_num_per_peak:(group_id + 1) * n_num_per_peak][:, None]
    dataset['test_label'] = y_sample[group_id * n_num_per_peak:(group_id + 1) * n_num_per_peak][:, None]
    model.train(dataset, opt='LBFGS', steps=100, update_grid=False)
    y_pred = model(x_grid[:, None])
    ys.append(y_pred.detach().numpy()[:, 0])

(4)每个阶段之后KAN的预测
我们绘制KAN在每个阶段之后的预测结果,以观察其学习过程。

plt.subplots(1, 5, figsize=(15, 2))
plt.subplots_adjust(wspace=0, hspace=0)

for i in range(1, 6):
    plt.subplot(1, 5, i)
    group_id = i - 1
    plt.plot(x_grid.detach().numpy(), y.detach().numpy(), color='black', alpha=0.1)
    plt.plot(x_grid.detach().numpy(), ys[i - 1], color='black')
    plt.xlim(-1, 1)
    plt.ylim(-1, 2)
plt.show()

请添加图片描述
通过以上步骤,我们成功实现了使用KAN进行的持续学习。在每个阶段,模型仅学习当前阶段的数据,但最终能够很好地拟合整个函数。这展示了KAN在处理持续学习任务中的有效性。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值