import snntorch as snn
from snntorch import surrogate
from snntorch import functional as SF
from snntorch import utils
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
batch_size = 64
data_path = 'data'
dtype = torch.float
# Prefer a GPU when one is present; the model and every batch are moved here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# MNIST images are already 28x28 single-channel, so Resize/Grayscale are
# effectively no-ops here; they keep the pipeline usable with other image
# datasets. Normalize((0,), (1,)) subtracts 0 and divides by 1, i.e. identity.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0,), (1,)),
])

mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
# Evaluate on the *whole* test set in a fixed order. With drop_last=True the
# final partial batch (10000 % 64 = 16 samples) was silently excluded, and with
# shuffle=True a different random subset was excluded on every evaluation,
# making successive test accuracies incomparable.
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, drop_last=False)
# Surrogate gradient used in place of the non-differentiable spike during backprop.
spike_grad = surrogate.fast_sigmoid(slope=25)
beta = 0.5       # membrane-potential decay factor passed to every Leaky layer
num_steps = 50   # simulation time steps per input sample


def _lif(**extra):
    """Build a Leaky neuron layer sharing the module-level beta/spike_grad settings."""
    return snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, **extra)


net = nn.Sequential(
    # 1x28x28 -> conv5 -> 12x24x24 -> pool2 -> 12x12x12
    nn.Conv2d(1, 12, 5),
    nn.MaxPool2d(2),
    _lif(),
    # 12x12x12 -> conv5 -> 64x8x8 -> pool2 -> 64x4x4
    nn.Conv2d(12, 64, 5),
    nn.MaxPool2d(2),
    _lif(),
    # 64x4x4 -> 1024 features -> 10 class outputs
    nn.Flatten(),
    nn.Linear(64 * 4 * 4, 10),
    # output=True: the final layer returns (spikes, membrane) so callers can
    # unpack both from a single call to net(...).
    _lif(output=True),
).to(device)
def forward_pass(net, num_steps, data):
    """Run ``net`` on the same input for ``num_steps`` time steps.

    The network's hidden state is reset first, so each call starts from a clean
    state. At every step the network must return a ``(spikes, membrane)`` pair.

    Args:
        net: Spiking network whose last layer returns ``(spikes, membrane)``.
        num_steps: Number of time steps to simulate.
        data: Input presented unchanged at every time step.

    Returns:
        ``(spikes, membranes)``: the per-step outputs stacked along a new
        leading time dimension of length ``num_steps``.
    """
    utils.reset(net)  # clear hidden state left over from any previous batch
    spikes = []
    membranes = []
    for _ in range(num_steps):
        spk, mem = net(data)
        spikes.append(spk)
        membranes.append(mem)
    return torch.stack(spikes), torch.stack(membranes)
# Rate-coded cross-entropy loss from snntorch.functional; it is applied to the
# time-stacked spike record produced by forward_pass and the integer targets.
# NOTE(review): per the snntorch docs this scores the output spikes accumulated
# over time — confirm against the installed snntorch version.
loss_fn = SF.ce_rate_loss()
def batch_accuracy(train_loader, net, num_steps):
    """Return the sample-weighted accuracy of ``net`` over every batch in a loader.

    Despite its name, ``train_loader`` may be any iterable of ``(data, targets)``
    batches (the training loop passes the test loader); the parameter name is
    kept only for backward compatibility with keyword callers.

    The network is run in eval mode under ``torch.no_grad()``; its previous
    train/eval mode is restored before returning (even on error).

    Args:
        train_loader: Iterable yielding ``(data, targets)`` batches.
        net: Spiking network compatible with ``forward_pass``.
        num_steps: Number of simulation time steps per batch.

    Returns:
        The ``SF.accuracy_rate`` score averaged over all samples, with each batch
        weighted by its size so a short final batch is counted correctly.

    Raises:
        ValueError: If the loader yields no batches.
    """
    was_training = net.training
    acc = 0
    total = 0
    try:
        with torch.no_grad():
            net.eval()
            for data, targets in train_loader:
                data = data.to(device)
                targets = targets.to(device)
                spk_rec, _ = forward_pass(net, num_steps, data)
                # spk_rec is stacked over time (dim 0), so dim 1 is the batch.
                batch_len = spk_rec.size(1)
                acc += SF.accuracy_rate(spk_rec, targets) * batch_len
                total += batch_len
    finally:
        # Leave the network in whatever mode the caller had it in.
        net.train(was_training)

    if total == 0:
        raise ValueError("batch_accuracy: loader yielded no batches")
    return acc / total
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))
num_epochs = 1

loss_hist = []      # training loss, one entry per optimizer step
test_acc_hist = []  # test-set accuracy, one entry per evaluation
counter = 0         # global iteration index, continues across epochs

for epoch in range(num_epochs):
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)

        # --- one optimisation step over num_steps simulated time steps ---
        net.train()
        spk_rec, _ = forward_pass(net, num_steps, data)
        loss_val = loss_fn(spk_rec, targets)

        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        loss_hist.append(loss_val.item())

        # --- every 50 iterations, score the full test set ---
        if not counter % 50:
            with torch.no_grad():
                net.eval()
                test_acc = batch_accuracy(test_loader, net, num_steps)
                print(f"Iteration {counter}, Test Acc: {test_acc * 100:.2f}%\n")
                test_acc_hist.append(test_acc.item())

        counter += 1
# Each entry of test_acc_hist is one test-set evaluation, taken whenever the
# global iteration counter is a multiple of 50 — not once per epoch. Plot
# against the actual iteration numbers so the x-axis is truthful.
# Keep eval_interval in sync with the evaluation interval in the training loop.
eval_interval = 50
eval_iterations = [i * eval_interval for i in range(len(test_acc_hist))]

fig = plt.figure(facecolor="w")
plt.plot(eval_iterations, test_acc_hist)
plt.title("Test Set Accuracy")
plt.xlabel("Iteration")
plt.ylabel("Accuracy")
plt.show()