(五)使用snntorch训练脉冲神经网络


$ pip install snntorch
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import itertools

1.脉冲神经网络的递归表示

在这里插入图片描述
脉冲用下式表示,如果膜电位超过阈值,就会发出一个脉冲:
在这里插入图片描述

2.脉冲的不可微性

在这里插入图片描述
在阈值处,导数趋于无穷大,梯度总是归零,无法进行学习,这被称为死神经元问题。

2.2 克服死神经元问题

*替代梯度法
*在这里插入图片描述
调用 snn.Leaky 神经元也能实现同样的效果。 事实上,每次从 snnTorch 调用任何神经元模型时, ATan 替代梯度都会默认应用于该神经元:

lif1 = snn.Leaky(beta=0.9)

3. 通过时间反向传播

4. 设置输出函数、输出解码

在传统的非脉冲神经网络中,有监督的多类分类问题会选取 激活度最高的神经元,并将其作为预测类别。
在脉冲神经网络中,有多种解释输出脉冲的方式。最常见的方法包括:

脉冲率编码:选择具有最高脉冲率(或脉冲计数)的神经元作为预测类别
延迟编码:选择首先发放脉冲的神经元作为预测类别

这可能会让联想到教程(一)脉冲编码。不同之处在于,在这里,我们是在解释(解码)输出脉冲,而不是将原始输入数据编码/转换成脉冲。

让我们专注于脉冲率编码。当输入数据传递到网络时, 我们希望正确的神经元类别在仿真运行的过程中发射最多的脉冲。 这对应于最高的平均脉冲频率。实现这一目标的一种方法是增加正确类别的膜电位至 U>Uthr, 并将不正确类别的膜电位设置为 U<Uthr。

这可以通过对输出神经元的膜电位取softmax来实现,其中C是输出类别的数量:
在这里插入图片描述
通过以下方式获取pi和目标 yi 之间的交叉熵, 目标是一个独热(one-hot)目标向量:
在这里插入图片描述
实际效果是,鼓励正确类别的膜电位增加,而不正确类别的膜电位降低。 这意味着在所有时间步中鼓励正确类别激活,且在所有时间步中抑制不正确类别。 这可能不是脉冲神经网络的最高效实现之一,但它是其中最简单的之一。
这个目标应用于仿真的每个时间步,因此也在每个步骤生成一个损失。 然后在仿真结束时将这些损失相加:
在这里插入图片描述

5.配置静态MNIST数据集

# dataloader 
batch_size = 128
data_path='/tmp/data/mnist'

dtype = torch.float
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
# transform
transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0,), (1,))])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
# 创建 DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

6. 定义网络

# 网络框架
num_inputs = 28*28
num_hidden = 1000
num_outputs = 10

# Temporal Dynamics
num_steps = 25
beta = 0.95
# 定义网络
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        # Initialize layers
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta)

    def forward(self, x):

        # Initialize hidden states at t=0
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for step in range(num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

# Load the network onto CUDA if available
net = Net().to(device)

fc1 对来自MNIST数据集的所有输入像素应用线性变换;
lif1 集成了随时间变化的加权输入,如果满足阈值条件,则发放脉冲;
fc2 对 lif1 的输出脉冲应用线性变换;
lif2 是另一层脉冲神经元,集成了随时间变化的加权脉冲。

7.训练SNN

7.1 准确率指标

下面这个函数会获取一批数据、统计每个神经元的所有脉冲(即模拟时间内的脉冲率代码), 并将最高计数的索引与实际目标进行比较。如果两者匹配,则说明网络正确预测了目标。


def print_batch_accuracy(data, targets, train=False):
    output, _ = net(data.view(batch_size, -1))
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())

    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer():
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(data, targets, train=True)
    print_batch_accuracy(test_data, test_targets, train=False)
    print("\n")

7.2 损失函数定义

loss = nn.CrossEntropyLoss()

7.3 优化器

Adam 是一个稳健的优化器,在递归网络中表现出色, 因此我们应用Adam并将其学习率为5e-4

optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

7.4 一次训练迭代

data, targets = next(iter(train_loader))#每次迭代都会获取一批新的数据和目标
data = data.to(device)
targets = targets.to(device)

将输入数据拍扁为大小为 784 的向量,并将其传入网络。

spk_rec, mem_rec = net(data.view(batch_size, -1))
print(mem_rec.size())
torch.Size([25, 128, 10])#25个时间步长,128个数据样本,10个输出神经元

计算出每个时间步的损失,并将这些损失相加:

# initialize the total loss value
loss_val = torch.zeros((1), dtype=dtype, device=device)

# sum loss at every step
for step in range(num_steps):
  loss_val += loss(mem_rec[step], targets)

对网络进行一次权重更新:

# 清除之前的梯度
optimizer.zero_grad()

# 计算梯度
loss_val.backward()

# 权重优化
optimizer.step()

现在,在一次迭代后重新运行损失计算和精度:

# calculate new network outputs using the same data
spk_rec, mem_rec = net(data.view(batch_size, -1))

# initialize the total loss value
loss_val = torch.zeros((1), dtype=dtype, device=device)

# sum loss at every step
for step in range(num_steps):
  loss_val += loss(mem_rec[step], targets)

只经过一次迭代

7.5 训练循环

num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0
    train_batch = iter(train_loader)

    # Minibatch training loop
    for data, targets in train_batch:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        spk_rec, mem_rec = net(data.view(batch_size, -1))

        # initialize the loss & sum over time
        loss_val = torch.zeros((1), dtype=dtype, device=device)
        for step in range(num_steps):
            loss_val += loss(mem_rec[step], targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            test_data, test_targets = next(iter(test_loader))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device)

            # Test set forward pass
            test_spk, test_mem = net(test_data.view(batch_size, -1))

            # Test set loss
            test_loss = torch.zeros((1), dtype=dtype, device=device)
            for step in range(num_steps):
                test_loss += loss(test_mem[step], test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            if counter % 50 == 0:
                train_printer()
            counter += 1
            iter_counter +=1

终端每迭代 50 次就会打印出类似的内容:

Epoch 0, Iteration 50
Train Set Loss: 12.63
Test Set Loss: 13.44
Train set accuracy for a single minibatch: 92.97%
Test set accuracy for a single minibatch: 90.62%

8. 结果

8.1 训练、测试损失

fig = plt.figure(facecolor="w", figsize=(10, 5))
plt.plot(loss_hist)
plt.plot(test_loss_hist)
plt.title("Loss Curves")
plt.legend(["Train Loss", "Test Loss"])
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()

在这里插入图片描述

8.2 测试集准确率

### 脉冲神经网络用于图像分割 脉冲神经网络(SNNs)模仿生物神经系统的工作方式,在处理时空数据方面具有独特优势。对于图像分割任务,SNN可以捕捉像素间的动态关系并实现高效计算。 #### SNN架构设计 为了适应图像分割需求,通常采用卷积层来提取特征图谱。不同于传统人工神经元模型,SNN中的神经元通过发放尖峰信号传递信息[^1]。这种机制使得网络能够更好地模拟真实大脑活动模式,并可能带来更低能耗与更优性能表现。 #### 实现流程概述 构建基于SNN的图像分割系统涉及以下几个关键技术环节: - **输入编码**: 将灰度级或RGB颜色值转换成时间序列形式表示; - **前向传播算法**: 利用膜电位更新规则完成逐层激活过程; - **监督学习方法**: 应用STDP(突触可塑性原理)或其他适合于离散事件驱动框架下的训练策略调整权重参数; ```python import torch from snntorch import spikegen, surrogate # 定义超参数 num_steps = 50 beta = 0.9 # 创建一个简单的两层SNN分类器 class Net(torch.nn.Module): def __init__(self): super().__init__() self.fc1 = torch.nn.Linear(784, 500) self.lif1 = surrogate.Lapicque(beta=beta) def forward(self, x): spk_rec = [] mem_rec = [] for step in range(num_steps): cur = self.fc1(x) spk, mem = self.lif1(cur) spk_rec.append(spk) mem_rec.append(mem) return torch.stack(spk_rec), torch.stack(mem_rec) net = Net() data_loader = ... # 加载MNIST等公开可用的数据集作为测试案例 for data, targets in data_loader: spikes = spikegen.rate(data.view(-1, 784)) spk_out, _ = net(spikes) ``` 上述代码片段展示了如何利用`snntorch`库快速搭建起基础版的SNN结构来进行简单二值化图片识别实验。实际应用到复杂场景比如医学影像分析时还需要进一步优化网络拓扑以及探索更适合的任务导向型损失函数设计思路。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值