A Simple SNN Example

# imports
import snntorch as snn
from snntorch import surrogate
# from snntorch import backprop
from snntorch import functional as SF
from snntorch import utils
# from snntorch import spikeplot as splt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# import torch.nn.functional as F

import matplotlib.pyplot as plt

# import numpy as np
# import itertools

# dataloader arguments
batch_size = 64  # not enough memory for 128
data_path = 'data'

dtype = torch.float
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu")  # the CUDA version on this machine is too old

# Define a transform
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,))])  # normalizing with mean 0 and std 1 is effectively a no-op

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# Create DataLoaders
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)
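# Optional sanity check on one batch (shapes follow from the transform and batch_size above):
data_check, targets_check = next(iter(train_loader))
print(data_check.shape, targets_check.shape)  # expected: torch.Size([64, 1, 28, 28]) torch.Size([64])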

# neuron and simulation parameters
spike_grad = surrogate.fast_sigmoid(slope=25)  # surrogate gradient for backpropagating through the spike
beta = 0.5  # decay rate of the neuron membrane potential
num_steps = 50  # number of simulation time steps (specific to SNNs)
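
# Rough sketch of what a single Leaky (LIF) layer does per time step, to illustrate the
# parameters above: the membrane potential decays by beta, integrates the input current,
# and a spike is emitted (with the threshold subtracted) when the membrane crosses threshold,
# roughly mem[t+1] = beta * mem[t] + input[t] - spike[t] * threshold.
# Standalone illustration with made-up input, not part of the network below:
lif_demo = snn.Leaky(beta=beta, spike_grad=spike_grad)
mem_demo = lif_demo.init_leaky()
spk_demo, mem_demo = lif_demo(torch.rand(1, 10), mem_demo)  # spike output and updated membrane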

# Define Network
# class Net(nn.Module):
#     def __init__(self):
#         super().__init__()
#
#         # Initialize layers
#         self.conv1 = nn.Conv2d(1, 12, 5)
#         self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)
#         self.conv2 = nn.Conv2d(12, 64, 5)
#         self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)
#         self.fc1 = nn.Linear(64*4*4, 10)
#         self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)
#
#     def forward(self, x):
#
#         # Initialize hidden states and outputs at t=0
#         mem1 = self.lif1.init_leaky()
#         mem2 = self.lif2.init_leaky()
#         mem3 = self.lif3.init_leaky()
#
#         cur1 = F.max_pool2d(self.conv1(x), 2)
#         spk1, mem1 = self.lif1(cur1, mem1)
#
#         cur2 = F.max_pool2d(self.conv2(spk1), 2)
#         spk2, mem2 = self.lif2(cur2, mem2)
#
#         cur3 = self.fc1(spk2.view(batch_size, -1))
#         spk3, mem3 = self.lif3(cur3, mem3)
#
#         return spk3, mem3

#  Initialize Network
net = nn.Sequential(nn.Conv2d(1, 12, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Conv2d(12, 64, 5),
                    nn.MaxPool2d(2),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),
                    nn.Flatten(),
                    nn.Linear(64 * 4 * 4, 10),
                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)
                    ).to(device)
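
# Optional shape check (illustrative): 28x28 input -> conv5/pool -> 12x12 -> conv5/pool -> 4x4,
# so the flattened feature size is 64*4*4 and the output layer emits 10 spikes per sample.
utils.reset(net)
spk_check, mem_check = net(torch.zeros(batch_size, 1, 28, 28, device=device))
print(spk_check.shape)  # expected: torch.Size([64, 10])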


def forward_pass(net, num_steps, data):
    mem_rec = []
    spk_rec = []
    utils.reset(net)  # resets hidden states for all LIF neurons in net

    for step in range(num_steps):
        spk_out, mem_out = net(data)
        spk_rec.append(spk_out)
        mem_rec.append(mem_out)
    # note: hidden states are reset once per forward pass (via utils.reset above),
    # not re-initialized at every time step, so membrane potentials carry over across steps
    return torch.stack(spk_rec), torch.stack(mem_rec)
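
# forward_pass returns spk_rec and mem_rec stacked over time, i.e. tensors of shape
# [num_steps, batch_size, 10]; the loss and accuracy metrics below work along this time axis.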


loss_fn = SF.ce_rate_loss()
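# ce_rate_loss treats the output spikes as a rate code: roughly speaking, cross-entropy is
# applied to the spiking output at each time step and averaged over steps, so the correct
# class is pushed to fire with the highest rate (see the snntorch.functional docs).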


def batch_accuracy(train_loader, net, num_steps):
    # accuracy averaged over every batch of the given loader (used below on the test set)
    with torch.no_grad():
        total = 0
        acc = 0
        net.eval()

        train_loader = iter(train_loader)
        for data, targets in train_loader:
            data = data.to(device)
            targets = targets.to(device)
            spk_rec, _ = forward_pass(net, num_steps, data)

            acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)
            total += spk_rec.size(1)

    return acc / total
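
# Note: SF.accuracy_rate counts a sample as correct when its most frequently spiking output
# neuron matches the target; weighting by spk_rec.size(1) makes the average per-sample
# rather than per-batch.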


optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))
num_epochs = 1
loss_hist = []
test_acc_hist = []
counter = 0

# Outer training loop
for epoch in range(num_epochs):  # only a single training epoch here

    # Training loop
    for data, targets in iter(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        # forward pass
        net.train()
        # one forward pass = running the network for num_steps time steps
        spk_rec, _ = forward_pass(net, num_steps, data)

        # initialize the loss & sum over time
        loss_val = loss_fn(spk_rec, targets)  # compute the loss against the true labels

        # Gradient calculation + weight update
        optimizer.zero_grad()
        # backpropagate to compute gradients from the loss
        loss_val.backward()
        # update the weights
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        if counter % 50 == 0:
            with torch.no_grad():
                net.eval()
                
                # Test set forward pass
                test_acc = batch_accuracy(test_loader, net, num_steps)
                print(f"Iteration {counter}, Test Acc: {test_acc * 100:.2f}%\n")
                test_acc_hist.append(test_acc.item())

        counter += 1

# Plot test accuracy
fig = plt.figure(facecolor="w")
plt.plot(test_acc_hist)
plt.title("Test Set Accuracy")
plt.xlabel("Evaluation step (every 50 iterations)")
plt.ylabel("Accuracy")
plt.show()
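
# loss_hist is recorded every training iteration but never plotted above; a minimal sketch:
plt.figure(facecolor="w")
plt.plot(loss_hist)
plt.title("Training Loss")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.show()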

