SpikingJelly笔记之梯度替代


前言

在SpikingJelly使用梯度替代训练SNN,构建单层全连接SNN实现MNIST分类任务。


一、梯度替代

1、梯度替代:

阶跃函数不可微,无法进行反向传播

g ( x ) = { 1 , x ≥ 0 0 , x < 0 g(x) = \left\{\begin{matrix} 1,&\quad x\ge 0\\ 0,&\quad x<0\\ \end{matrix}\right. g(x)={1,0,x0x<0 , ,\quad\quad\quad g ′ ( x ) = { + ∞ , x = 0 0 , x ≠ 0 g^{\prime}(x) = \left\{\begin{matrix} +∞&,\quad x= 0\\ 0&,\quad x\neq0\\ \end{matrix}\right. g(x)={+0,x=0,x=0

前向传播使用阶跃函数,反向传播使用替代函数

2、梯度替代函数:

来源:spikingjelly.activation_based.surrogate package

①Sigmoid:surrogate.Sigmoid(alpha=4.0, spiking=True)

g ( x ) = s i g m o i d ( α x ) = 1 1 + e − α x g(x) = sigmoid(\alpha x)=\frac{1}{1+e^{-\alpha x}} g(x)=sigmoid(αx)=1+eαx1

g ′ ( x ) = α ∗ s i g m o i d ( α x ) ∗ ( 1 − s i g m o i d ( α x ) ) g^{\prime}(x) = \alpha*sigmoid(\alpha x)*(1-sigmoid(\alpha x)) g(x)=αsigmoid(αx)(1sigmoid(αx))

②ATan:surrogate.ATan(alpha=2.0, spiking=True)

g ( x ) = 1 π a r c t a n ( π 2 α x ) + 1 2 g(x) = \frac{1}{\pi}arctan(\frac{\pi}{2}\alpha x)+\frac{1}{2} g(x)=π1arctan(2παx)+21

g ′ ( x ) = α 2 ( 1 + ( π 2 α x ) 2 ) g^{\prime}(x) = \frac{\alpha}{2(1+(\frac{\pi}{2}\alpha x)^2)} g(x)=2(1+(2παx)2)α

③SoftSign:surrogate.SoftSign(alpha=2.0, spiking=True)

g ( x ) = 1 2 ( α x 1 + ∣ α x ∣ + 1 ) g(x) = \frac{1}{2}(\frac{\alpha x}{1+|\alpha x|}+1) g(x)=21(1+αxαx+1)

g ′ ( x ) = α 2 ( 1 + ∣ α x ∣ 2 ) g^{\prime}(x) = \frac{\alpha}{2(1+|\alpha x|^2)} g(x)=2(1+αx2)α

④LeakyKReLU:surrogate.LeakyKReLU(spiking=True, leak: float=0.0, k: float=1.0)

g ( x ) = { k ∗ x , x ≥ 0 l e a k ∗ x , x < 0 g(x) = \left\{\begin{matrix} k*x,&\quad x\ge 0\\ leak*x,&\quad x<0\\ \end{matrix}\right. g(x)={kx,leakx,x0x<0 , ,\quad\quad\quad g ′ ( x ) = { k , x ≥ 0 l e a k , x < 0 g^{\prime}(x) = \left\{\begin{matrix} k&,\quad x\ge 0\\ leak&,\quad x<0\\ \end{matrix}\right. g(x)={kleak,x0,x<0

二、网络结构

使用神经元层替代激活函数

1、ANN

nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 10, bias=False),
    nn.Softmax()
    )

2、SNN

nn.Sequential(
    layer.Flatten(),
    layer.Linear(28 * 28, 10, bias=False),
    neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan())
    )

三、MNIST分类

1、单步模式

(1)导入库

import time
import numpy as np
from matplotlib import pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from spikingjelly.activation_based import neuron, encoding,\
    functional, surrogate, layer, monitor
from spikingjelly import visualizing
from load_mnist import load_mnist

(2)构建数据加载器

将numpy数据封装成DataLoader

使用Pytorch自带的数据集会更方便

def To_loader(x_train, y_train, x_test, y_test, batch_size):
    # 转为张量
    x_train = torch.from_numpy(x_train.astype(np.float32))
    y_train = torch.from_numpy(y_train.astype(np.float32))
    x_test = torch.from_numpy(x_test.astype(np.float32))
    y_test = torch.from_numpy(y_test.astype(np.float32))
    # 数据集封装
    train_dataset = TensorDataset(x_train, y_train)
    test_dataset = TensorDataset(x_test, y_test)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    return train_dataset, test_dataset, train_loader, test_loader

(3)构建SNN模型

将LIF神经元层当作激活函数使用

使用ATan作为梯度替代函数进行反向传播

class SNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Sequential(
            layer.Linear(784, 10, bias=False),
            neuron.LIFNode(tau=2.0,
                           decay_input=True,
                           v_threshold=1.0,
                           v_reset=0.0,
                           surrogate_function=surrogate.ATan(),
                           step_mode='s',
                           store_v_seq=False)
            )
    def forward(self, x):
        return self.layer(x)

(4)训练参数

使用泊松编码器对输入进行编码

取10000个样本进行训练

epoch_num = 10
batch_size = 256
T = 50
lr = 0.001
encoder = encoding.PoissonEncoder() # 泊松编码器
model = SNN() # 单层SNN
loss_function = nn.MSELoss() # 均方误差
optimizer = optim.Adam(model.parameters(), lr) # Adam优化器
x_train, y_train, x_test, y_test = \
    load_mnist(normalize=True, flatten=False, one_hot_label=True)
train_dataset, test_dataset, train_loader, test_loader =\
    To_loader(x_train[:10000], y_train[:10000], x_test, y_test, batch_size)

(5)迭代训练

①取一段时间的平均发放率作为输出

②损失函数采用交叉熵或均方差,使对应神经元fout→1,其他神经元fout→0

③每批训练后重置网络状态

④每轮训练后测试准确率

start_time = time.time()
loss_train_list = []
acc_train_list = []
acc_test_list = []
for epoch in range(epoch_num):
    print('Epoch:%s'%(epoch+1))
    # 模型训练
    loss_train = 0
    acc_train = 0
    for x, y in train_loader:
        f_out = torch.zeros((y.shape[0], 10)) # 输出频率
        # 前向计算,逐步传播
        for t in range(T):
            encoded_x = encoder(x.reshape(-1, 784))
            f_out += model(encoded_x)
        f_out /= T
        # 反向传播
        loss = loss_function(f_out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 计算损失值与准确率
        loss_train += loss.item()
        acc_train += (f_out.argmax(1) == y.argmax(1)).sum().item()
        # 清除状态
        functional.reset_net(model)
    acc_train /= len(train_dataset)
    loss_train_list.append(loss_train)
    acc_train_list.append(acc_train)
    print('loss_train:', loss_train)
    print('acc_train:{:.2%}:'.format(acc_train))
    # 模型测试
    with torch.no_grad():
        acc_test = 0
        for x, y in test_loader:
            f_out = torch.zeros((y.shape[0], 10))
            # 逐步传播
            for t in range(T):
                encoded_x = encoder(x.reshape(-1,784))
                f_out += model(encoded_x)
            f_out /= T
            loss = loss_function(f_out, y)
            acc_test += (f_out.argmax(1) == y.argmax(1)).sum().item()
            functional.reset_net(model)
        acc_test /= len(test_dataset)
        acc_test_list.append(acc_test)
        print('acc_test:{:.2%}'.format(acc_test))
end_time = time.time()
print('Time:{:.1f}s'.format(end_time - start_time))

训练结果:

Epoch:10
loss_train: 0.8223596904426813
acc_train:91.10%
acc_test:90.24%
Time:123.3s

(6)显示损失值与准确率变化

fig1 = plt.figure(1, figsize=(12, 6))
ax1 = fig1.add_subplot(2, 2, 1)
ax1.plot(loss_train_list, 'r-')
ax1.set_title('loss_train')
ax2 = fig1.add_subplot(2, 2, 2)
ax2.plot(acc_train_list, 'b-')
ax2.set_title('acc_train')
ax3 = fig1.add_subplot(2, 1, 2)
ax3.plot(acc_test_list, 'b-')
ax3.set_title('acc_test')
plt.show()

训练结果:

(7)结果预测

选取一个数据,观察各神经元的膜电位变化与输出情况

# 设置监视器
for m in model.modules():
    if isinstance(m, neuron.LIFNode):
        m.store_v_seq = True
monitor_o = monitor.OutputMonitor(model, neuron.LIFNode)
monitor_v = monitor.AttributeMonitor('v',
                                      pre_forward=False,
                                      net=model,
                                      instance=neuron.LIFNode)
print('model:', model)
print('monitor_v:', monitor_v.monitored_layers)
print('monitor_o:', monitor_o.monitored_layers)
# 选择一组输入
x, y = test_dataset[0]
f_out = torch.zeros((y.shape[0], 10))
with torch.no_grad():
    # 逐步传播
    for t in range(T):
        encoded_x = encoder(x.reshape(-1,784))
        f_out += model(encoded_x)
    functional.reset_net(model)
    label = y.argmax().item()
    pred = f_out.argmax().item()
print('label:{},predict:{}'.format(label, pred))
# 膜电位与输出可视化
# 膜电位变化
dpi = 100
figsize = (6, 4)
# 合并列表中的张量,删除多余维度,删除梯度信息
v_list = torch.stack(monitor_v['layer.1']).squeeze().detach()
visualizing.plot_2d_heatmap(array=v_list.numpy(),
                            title='Membrane Potentials',
                            xlabel='Simulating Step',
                            ylabel='Neuron Index',
                            int_x_ticks=True,
                            x_max=T,
                            figsize=figsize,
                            dpi=dpi)
# 神经元输出
s_list = torch.stack(monitor_o['layer.1']).squeeze().detach()
visualizing.plot_1d_spikes(spikes=s_list.numpy(),
                            title='Out Spikes',
                            xlabel='Simulating Step',
                            ylabel='Neuron Index',
                            figsize=figsize,
                            dpi=dpi)

预测结果:

model: SNN(
  (layer): Sequential(
    (0): Linear(in_features=784, out_features=10, bias=False)
    (1): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=s, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
  )
)
monitor_v: ['layer.1']
monitor_o: ['layer.1']
label:7,predict:7

膜电位变化:

神经元输出:

2、多步模式

将单步模式改为多步模式,需要修改以下部分:

(1)将神经元层的步进模式由’s’改为’m’

neuron.LIFNode(tau=2.0,
               decay_input=True,
               v_threshold=1.0,
               v_reset=0.0,
               surrogate_function=surrogate.ATan(),
               step_mode='m',
               store_v_seq=False)

(2)一次将所有时间步的数据全部输入

            encoded_x = encoder(x).repeat(T,1,1))
            f_out += model(encoded_x).sum(axis=0)
            f_out /= T

(3)修改监视器监视的变量

monitor_v = monitor.AttributeMonitor('v_seq',
                                      pre_forward=False,
                                      net=model,
                                      instance=neuron.LIFNode)

输出情况:

①训练结果

Epoch:10
loss_train: 0.8167978068813682
acc_train:91.06%:
acc_test:89.78%:
Time:145.1s

②网络结构

model: SNN(
  (layer): Sequential(
    (0): Linear(in_features=784, out_features=10, bias=False)
    (1): LIFNode(
      v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch, tau=2.0
      (surrogate_function): ATan(alpha=2.0, spiking=True)
    )
  )
)
monitor_v: ['layer.1']
monitor_o: ['layer.1']
label:7,predict:7

③膜电位变化

④神经元输出:


总结

使用梯度替代法进行反向传播时,使用可微的激活函数替代,避免脉冲的不可微;

使用编码器将输入编码为1/0脉冲序列;

将神经元层代替激活函数;

“在正确构建网络的情况下,逐层传播的并行度更大,速度更快”。但在此逐步传播比逐层传播略快一些。

  • 22
    点赞
  • 19
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值