1.安装
注意,SpikingJelly是基于PyTorch的,需要确保环境中已经安装了PyTorch,才能安装SpikingJelly。
奇数版本是开发版,随着GitHub/OpenI不断更新。偶数版本是稳定版,可以从PyPI获取。
从 PyPI 安装最新的稳定版本(0.0.0.0.6):
pip install spikingjelly
2.构建第一个SNN网络对Mnist数据集分类
引用spikingjelly..clock_driven里的neuron就可以使用脉冲神经元了,这里以LIF为例:
import torch
import torch.nn as nn
from spikingjelly.clock_driven import neuron
class Net(nn.Module):
def __init__(self,tau=100.0,v_threshold=1.0,v_reset=0.0):
super().__init__()
self.fc = nn.Sequential(
nn.Flatten(),
nn.Linear(28*28, 14*14, bias=False),
neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset),
nn.Linear(14*14,10,bias=False),
neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset),
)
def forward(self,x):
return self.fc(x)
也可以直接调用包里已经写好的程序来训练:
import spikingjelly.clock_driven.examples.lif_fc_mnist as lif_fc_mnist
lif_fc_mnist.main()
运行代码后会让你选择数据存储路径和训练的一些参数设置:
之后开始训练:
准确率能到85%: