SpikingJelly notes: classifying the MNIST dataset with a single-layer fully connected spiking neural network (SNN)

The training script is walked through below, starting from its command-line arguments; the complete script, including the imports and the SNN class definition, is listed again at the end of the post.

parser.add_argument('-data-dir', type=str, help='root dir of MNIST dataset')
parser.add_argument('-out-dir', type=str, default='./logs', help='root dir for saving logs and checkpoint')
parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
parser.add_argument('-amp', action='store_true', help='automatic mixed precision training')
parser.add_argument('-opt', type=str, choices=['sgd', 'adam'], default='adam', help='use which optimizer. SGD or Adam')
parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
parser.add_argument('-lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('-tau', default=2.0, type=float, help='parameter tau of LIF neuron')

args = parser.parse_args()
print(args)

net = SNN(tau=args.tau)

print(net)

net.to(args.device)

# Initialize the datasets and data loaders
train_dataset = torchvision.datasets.MNIST(
    root=args.data_dir,
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
test_dataset = torchvision.datasets.MNIST(
    root=args.data_dir,
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

train_data_loader = data.DataLoader(
    dataset=train_dataset,
    batch_size=args.b,
    shuffle=True,
    drop_last=True,
    num_workers=args.j,
    pin_memory=True
)
test_data_loader = data.DataLoader(
    dataset=test_dataset,
    batch_size=args.b,
    shuffle=False,
    drop_last=False,
    num_workers=args.j,
    pin_memory=True
)

scaler = None
if args.amp:
    scaler = amp.GradScaler()

start_epoch = 0
max_test_acc = -1

optimizer = None
if args.opt == 'sgd':
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
elif args.opt == 'adam':
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
else:
    raise NotImplementedError(args.opt)

if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    net.load_state_dict(checkpoint['net'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1
    max_test_acc = checkpoint['max_test_acc']

out_dir = os.path.join(args.out_dir, f'T{args.T}_b{args.b}_{args.opt}_lr{args.lr}')

if args.amp:
    out_dir += '_amp'

if not os.path.exists(out_dir):
    os.makedirs(out_dir)
    print(f'Mkdir {out_dir}.')

writer = SummaryWriter(out_dir, purge_step=start_epoch)
with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
    args_txt.write(str(args))
    args_txt.write('\n')
    args_txt.write(' '.join(sys.argv))

encoder = encoding.PoissonEncoder()
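# PoissonEncoder implements rate coding: at each time step every pixel emits a spike with
# probability equal to its normalized intensity (roughly spike = (torch.rand_like(img) <= img).float()),
# so averaging the network output over the T steps below yields a per-class firing rate.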

for epoch in range(start_epoch, args.epochs):
    start_time = time.time()
    net.train()
    train_loss = 0
    train_acc = 0
    train_samples = 0
    for img, label in train_data_loader:
        optimizer.zero_grad()
        img = img.to(args.device)
        label = label.to(args.device)
        label_onehot = F.one_hot(label, 10).float()

        if scaler is not None:
            with amp.autocast():
                out_fr = 0.
                for t in range(args.T):
                    encoded_img = encoder(img)
                    out_fr += net(encoded_img)
                out_fr = out_fr / args.T
                loss = F.mse_loss(out_fr, label_onehot)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            out_fr = 0.
            for t in range(args.T):
                encoded_img = encoder(img)
                out_fr += net(encoded_img)
            out_fr = out_fr / args.T
            loss = F.mse_loss(out_fr, label_onehot)
            loss.backward()
            optimizer.step()

        train_samples += label.numel()
        train_loss += loss.item() * label.numel()
        train_acc += (out_fr.argmax(1) == label).float().sum().item()

        functional.reset_net(net)

    train_time = time.time()
    train_speed = train_samples / (train_time - start_time)
    train_loss /= train_samples
    train_acc /= train_samples

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('train_acc', train_acc, epoch)

    net.eval()
    test_loss = 0
    test_acc = 0
    test_samples = 0
    with torch.no_grad():
        for img, label in test_data_loader:
            img = img.to(args.device)
            label = label.to(args.device)
            label_onehot = F.one_hot(label, 10).float()
            out_fr = 0.
            for t in range(args.T):
                encoded_img = encoder(img)
                out_fr += net(encoded_img)
            out_fr = out_fr / args.T
            loss = F.mse_loss(out_fr, label_onehot)

            test_samples += label.numel()
            test_loss += loss.item() * label.numel()
            test_acc += (out_fr.argmax(1) == label).float().sum().item()
            functional.reset_net(net)
    test_time = time.time()
    test_speed = test_samples / (test_time - train_time)
    test_loss /= test_samples
    test_acc /= test_samples
    writer.add_scalar('test_loss', test_loss, epoch)
    writer.add_scalar('test_acc', test_acc, epoch)

    save_max = False
    if test_acc > max_test_acc:
        max_test_acc = test_acc
        save_max = True

    checkpoint = {
        'net': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'max_test_acc': max_test_acc
    }

    if save_max:
        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

    torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

    print(args)
    print(out_dir)
    print(f'epoch ={epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}')
    print(f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s')
    print(f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n')

# Save data for plotting
net.eval()
# Register a forward hook on the output layer
output_layer = net.layer[-1]  # the output spiking layer
output_layer.v_seq = []
output_layer.s_seq = []
def save_hook(m, x, y):
    m.v_seq.append(m.v.unsqueeze(0))
    m.s_seq.append(y.unsqueeze(0))

output_layer.register_forward_hook(save_hook)


with torch.no_grad():
    img, label = test_dataset[0]
    img = img.to(args.device)
    out_fr = 0.
    for t in range(args.T):
        encoded_img = encoder(img)
        out_fr += net(encoded_img)
    out_spikes_counter_frequency = (out_fr / args.T).cpu().numpy()
    print(f'Firing rate: {out_spikes_counter_frequency}')

    output_layer.v_seq = torch.cat(output_layer.v_seq)
    output_layer.s_seq = torch.cat(output_layer.s_seq)
    v_t_array = output_layer.v_seq.cpu().numpy().squeeze()  # v_t_array[i][j] is the membrane potential of neuron i at time step j
    np.save("v_t_array.npy", v_t_array)
    s_t_array = output_layer.s_seq.cpu().numpy().squeeze()  # s_t_array[i][j] is the spike (0 or 1) emitted by neuron i at time step j
    np.save("s_t_array.npy", s_t_array)

if __name__ == '__main__':
    main()
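The script can be launched from the command line. A command along these lines would reproduce the settings printed below (the file name lif_fc_mnist.py and the data path are assumptions; the remaining options keep their parser defaults):

python lif_fc_mnist.py -T 100 -device cuda:0 -b 64 -epochs 50 -j 2 -data-dir ./mnist -amp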



Example output from one training run:

Namespace(T=100, amp=True, b=64, data_dir='\mnist', device='cuda:0', epochs=50, j=2, lr=0.001, momentum=0.9, opt='adam', out_dir='./logs', resume=None, tau=2.0)
./logs\T100_b64_adam_lr0.001_amp
epoch =49, train_loss = 0.0138, train_acc = 0.9324, test_loss = 0.0146, test_acc = 0.9269, max_test_acc = 0.9282
train speed = 1504.1307 images/s, test speed = 2240.2271 images/s
escape time = 2024-03-22 15:13:23

Firing rate: [[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]]
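The two arrays saved by the hook above, v_t_array.npy and s_t_array.npy, can be visualized to inspect the membrane potentials and spikes of the 10 output neurons. A minimal plotting sketch, assuming matplotlib is installed (this code is not part of the original script):

import numpy as np
import matplotlib.pyplot as plt

v_t_array = np.load("v_t_array.npy")  # membrane potential of the 10 output neurons over the T time steps
s_t_array = np.load("s_t_array.npy")  # 0/1 spikes of the 10 output neurons over the T time steps

if v_t_array.shape[0] != 10:          # make rows = neurons, columns = time steps
    v_t_array = v_t_array.T
    s_t_array = s_t_array.T

fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 4))
ax0.imshow(v_t_array, aspect='auto')  # heatmap of membrane potentials
ax0.set_xlabel('time step')
ax0.set_ylabel('neuron index')
ax0.set_title('membrane potential')

neuron_idx, t_idx = np.nonzero(s_t_array)  # which neuron fired at which time step
ax1.scatter(t_idx, neuron_idx, s=4)
ax1.set_xlabel('time step')
ax1.set_ylabel('neuron index')
ax1.set_title('output spikes')

plt.tight_layout()
plt.show()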


(SpikingJelly activation_based package location on the author's machine: C:\Users\wx\AppData\Local\Programs\Python\Python37\Lib\site-packages\spikingjelly\activation_based)



Create the data loader (for running predictions with a trained model)

test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

Batch prediction

for imgs, labels in test_loader:
    # With ToTensor() the images already have shape [N, 1, 28, 28]; only add the channel
    # dimension if the transform leaves them as [N, 28, 28].
    if imgs.dim() == 3:
        imgs = imgs.unsqueeze(1)
    with torch.no_grad():
        outputs = model(imgs)
    predicted_labels = outputs.argmax(dim=1)
    for i, label in enumerate(predicted_labels):
        print(f'Predicted label: {label.item()}, True label: {labels[i].item()}')

Alternatively, take a single image from the MNIST test set

test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
img, label = test_dataset[0]  # get the first image and its label
img = img.unsqueeze(0)  # add a batch dimension

Model inference

with torch.no_grad():
    output = model(img)

Parse the result

predicted_label = output.argmax(dim=1)
print(f'Predicted label: {predicted_label.item()}, True label: {label}')
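Note that the snippets above call the model once per image, which is fine for an ANN; for the rate-coded SNN trained in this post, the prediction should instead be averaged over T simulation steps and the neuron states reset afterwards. A minimal sketch, assuming net, encoding and functional from the training script are in scope (predict_snn is a hypothetical helper, not part of the original code):

def predict_snn(net, encoder, img, T=100, device='cuda:0'):
    # img: a single MNIST image tensor of shape [1, 28, 28], e.g. from test_dataset[0]
    net.eval()
    img = img.to(device).unsqueeze(0)   # add a batch dimension -> [1, 1, 28, 28]
    with torch.no_grad():
        out_fr = 0.
        for _ in range(T):
            out_fr += net(encoder(img))  # accumulate the output spikes over T steps
        out_fr = out_fr / T              # firing rate of the 10 output neurons
    functional.reset_net(net)            # clear membrane potentials before the next sample
    return out_fr.argmax(1).item()

# Example usage:
# img, label = test_dataset[0]
# print(predict_snn(net, encoding.PoissonEncoder(), img), label)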


=========================


The full training script (main)



import os
import time
import argparse
import sys
import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.cuda import amp
from torch.utils.tensorboard import SummaryWriter
import torchvision
import numpy as np

from spikingjelly.activation_based import neuron, encoding, functional, surrogate, layer

class SNN(nn.Module):
    def __init__(self, tau):
        super().__init__()

        self.layer = nn.Sequential(
            layer.Flatten(),
            layer.Linear(28 * 28, 20, bias=False),
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
            layer.Linear(20, 10, bias=False),
            neuron.LIFNode(tau=tau, surrogate_function=surrogate.ATan()),
        )

    def forward(self, x: torch.Tensor):
        return self.layer(x)
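# A quick shape check (a sketch, assuming the imports above):
#   net = SNN(tau=2.0)
#   x = torch.rand(4, 1, 28, 28)   # one simulation step for a batch of 4 images
#   print(net(x).shape)            # torch.Size([4, 10]): one spike (0 or 1) per output neuron
#   functional.reset_net(net)      # LIF neurons are stateful; reset before reusing the network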

def main():
    '''
    :return: None

    The network with FC-LIF structure for classifying MNIST.
    This function initializes the network, starts training, and shows the accuracy on the test dataset.
    '''
    parser = argparse.ArgumentParser(description='LIF MNIST Training')
    parser.add_argument('-T', default=100, type=int, help='simulating time-steps')
    parser.add_argument('-device', default='cuda:0', help='device')
    parser.add_argument('-b', default=64, type=int, help='batch size')
    parser.add_argument('-epochs', default=100, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-j', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-data-dir', type=str, help='root dir of MNIST dataset')
    parser.add_argument('-out-dir', type=str, default='./logs', help='root dir for saving logs and checkpoint')
    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
    parser.add_argument('-amp', action='store_true', help='automatic mixed precision training')
    parser.add_argument('-opt', type=str, choices=['sgd', 'adam'], default='adam', help='use which optimizer. SGD or Adam')
    parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
    parser.add_argument('-lr', default=1e-3, type=float, help='learning rate')
    parser.add_argument('-tau', default=2.0, type=float, help='parameter tau of LIF neuron')

    args = parser.parse_args()
    print(args)

    net = SNN(tau=args.tau)

    print(net)

    net.to(args.device)

    # Initialize the datasets and data loaders
    train_dataset = torchvision.datasets.MNIST(
        root=args.data_dir,
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True
    )
    test_dataset = torchvision.datasets.MNIST(
        root=args.data_dir,
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True
    )

    train_data_loader = data.DataLoader(
        dataset=train_dataset,
        batch_size=args.b,
        shuffle=True,
        drop_last=True,
        num_workers=args.j,
        pin_memory=True
    )
    test_data_loader = data.DataLoader(
        dataset=test_dataset,
        batch_size=args.b,
        shuffle=False,
        drop_last=False,
        num_workers=args.j,
        pin_memory=True
    )

    scaler = None
    if args.amp:
        scaler = amp.GradScaler()

    start_epoch = 0
    max_test_acc = -1

    optimizer = None
    if args.opt == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.opt == 'adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr)
    else:
        raise NotImplementedError(args.opt)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    out_dir = os.path.join(args.out_dir, f'T{args.T}_b{args.b}_{args.opt}_lr{args.lr}')

    if args.amp:
        out_dir += '_amp'  # mark in the log dir name whether mixed precision is used

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')

    writer = SummaryWriter(out_dir, purge_step=start_epoch)
    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    encoder = encoding.PoissonEncoder()

    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()
        net.train()
        train_loss = 0
        train_acc = 0
        train_samples = 0
        for img, label in train_data_loader:
            optimizer.zero_grad()
            img = img.to(args.device)
            label = label.to(args.device)
            label_onehot = F.one_hot(label, 10).float()

            if scaler is not None:  # mixed precision training
                with amp.autocast():
                    out_fr = 0.
                    for t in range(args.T):
                        encoded_img = encoder(img)  # the image must be Poisson-encoded anew at each of the T time steps
                        out_fr += net(encoded_img)
                    out_fr = out_fr / args.T
                    # out_fr is a tensor of shape [batch_size, 10]: the firing rate of the 10 output
                    # neurons over the whole simulation duration
                    loss = F.mse_loss(out_fr, label_onehot)
                # The loss is the MSE between the output neurons' firing rates and the one-hot true class.
                # It drives the firing rate of neuron i toward 1 when the label is i, and toward 0 otherwise.
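                # For example, with label = 3 the one-hot target is [0,0,0,1,0,0,0,0,0,0]; minimizing
                # F.mse_loss(out_fr, label_onehot) pushes neuron 3's firing rate toward 1 and the
                # firing rates of the other nine output neurons toward 0.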
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                out_fr = 0.
                for t in range(args.T):
                    encoded_img = encoder(img)  # Poisson-encode the image anew at each time step
                    out_fr += net(encoded_img)
                out_fr = out_fr / args.T
                loss = F.mse_loss(out_fr, label_onehot)
                loss.backward()
                optimizer.step()

            train_samples += label.numel()
            train_loss += loss.item() * label.numel()
            # Accuracy: the index of the output neuron with the highest firing rate is taken as the predicted class
            train_acc += (out_fr.argmax(1) == label).float().sum().item()
            # After each parameter update the network state must be reset, because SNN neurons have "memory"
            functional.reset_net(net)

        train_time = time.time()
        train_speed = train_samples / (train_time - start_time)
        train_loss /= train_samples
        train_acc /= train_samples

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('train_acc', train_acc, epoch)

        net.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0
        with torch.no_grad():
            for img, label in test_data_loader:
                img = img.to(args.device)
                label = label.to(args.device)
                label_onehot = F.one_hot(label, 10).float()
                out_fr = 0.
                for t in range(args.T):
                    encoded_img = encoder(img)
                    out_fr += net(encoded_img)
                out_fr = out_fr / args.T
                loss = F.mse_loss(out_fr, label_onehot)

                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (out_fr.argmax(1) == label).float().sum().item()
                functional.reset_net(net)
        test_time = time.time()
        test_speed = test_samples / (test_time - train_time)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)

        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True

        checkpoint = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
            'max_test_acc': max_test_acc
        }

        if save_max:
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))

        print(args)
        print(out_dir)
        print(f'epoch ={epoch}, train_loss ={train_loss: .4f}, train_acc ={train_acc: .4f}, test_loss ={test_loss: .4f}, test_acc ={test_acc: .4f}, max_test_acc ={max_test_acc: .4f}')
        print(f'train speed ={train_speed: .4f} images/s, test speed ={test_speed: .4f} images/s')
        print(f'escape time = {(datetime.datetime.now() + datetime.timedelta(seconds=(time.time() - start_time) * (args.epochs - epoch))).strftime("%Y-%m-%d %H:%M:%S")}\n')

    # Save data for plotting
    net.eval()
    # Register a forward hook on the output layer
    output_layer = net.layer[-1]  # the output spiking layer
    output_layer.v_seq = []
    output_layer.s_seq = []
    def save_hook(m, x, y):
        m.v_seq.append(m.v.unsqueeze(0))
        m.s_seq.append(y.unsqueeze(0))

    output_layer.register_forward_hook(save_hook)

    with torch.no_grad():  # no gradients are needed for prediction
        img, label = test_dataset[0]
        img = img.to(args.device)
        out_fr = 0.
        for t in range(args.T):
            encoded_img = encoder(img)
            out_fr += net(encoded_img)
        out_spikes_counter_frequency = (out_fr / args.T).cpu().numpy()
        print(f'Firing rate: {out_spikes_counter_frequency}')

        output_layer.v_seq = torch.cat(output_layer.v_seq)
        output_layer.s_seq = torch.cat(output_layer.s_seq)
        v_t_array = output_layer.v_seq.cpu().numpy().squeeze()  # v_t_array[i][j] is the membrane potential of neuron i at time step j
        np.save("v_t_array.npy", v_t_array)
        s_t_array = output_layer.s_seq.cpu().numpy().squeeze()  # s_t_array[i][j] is the spike (0 or 1) emitted by neuron i at time step j
        np.save("s_t_array.npy", s_t_array)


if __name__ == '__main__':
    main()
