spikingjelly中的ANN2SNN程序20201221是好使的

博客讲述了作者在使用spikingjelly不同版本进行ANN到SNN转换时遇到的问题,强调20201221版本相对可用。主要问题包括:ONNX包调用错误、代码中参数错误(如将norm_tensor误写为test_tensor)、utils.py中pytorch_ann2snn函数缺失loss_function参数等。最终,作者成功解决了这些问题并得到了转换结果。
摘要由CSDN通过智能技术生成

spikingjelly的最新版本,我当时使用的是20210330,中间存在onnx的包调用错误,难以解决。

通过调试20201221,发现这个版本的是好使的,但是也有一些是需要调试的。
1.首先应该训练出一个ANN模型出来,此时main函数中应该不填入任何信息,即

if __name__ == '__main__':
    main()
    # main('./log-cnn_mnist1622169436.7251258')

2.当训练出ANN模型后,会自动生成一个文件夹

./log-cnn_mnist1622169436.7251258

这个文件夹中存有大量的ANN和SNN模型
3.

 utils.pytorch_ann2snn(model_name=model_name,
                              norm_tensor=norm_tensor,
                              test_data_loader=test_data_loader,
                              device=device,
                              T=T,
                              log_dir=log_dir,
                              config=config
                              )

这里面有好几个坑:
3.1
原来的代码中,第二个参数写的不是norm_tensor,好像是test_tensor,这个容易发现,按着ctrl 点击pytorch_ann2snn就能很快找到错误。
3.2
utils.py 里面有大量的function
其中有一个pytorch_ann2snn,这个function里面有一行是使用val_ann的
正确代码如下:

ann_acc = val_ann(net=parsed_ann, loss_function=nn.CrossEntropyLoss(),device=device, data_loader=test_data_loader)

原来的代码中,好像是少了一个loss_function的参数
3.3
还是val_ann的问题,但是我忘记了。。。。

反正最后是成功了
控制台输出代码如下:

All the temp files are saved to  ./log-cnn_mnist1622169436.7251258
ann2snn config:
	 {
   'simulation': {
   'reset_to_zero': False, 'encoder': {
   'possion': False}, 'avg_pool': {
   'has_neuron': True}, 'max_pool': {
   'if_spatial_avg': False, 'if_wta': False, 'momentum': None}}, 'parser': {
   'robust_norm': True}}
Directly load model cnn_mnist.pkl
Using 120 pictures as norm set
Load best model for Model:cnn_mnist...
ANN Validating Accuracy:0.986
Save model to: ./log-cnn_mnist1622169436.7251258\parsed_cnn_mnist.pkl
Using robust normalization...
normalize with bias...
Print Parsed ANN model Structure:
Pytorch_Parser(
  (network): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (7): ReLU()
    (8): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (9): Flatten(start_dim=1, end_dim=-1)
    (10): Linear(in_features=32, out_features=10, bias=True)
    (11): ReLU()
  )
)
Save model to: ./log-cnn_mnist1622169436.7251258\normalized_cnn_mnist.pkl
Print Simulated SNN model Structure:
PyTorch_Converter(
  (network): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
    (1): IFNode(
      v_threshold=1.0, v_reset=None, detach_reset=False
      (surrogate_function): Sigmoid(alpha=1.0, spiking=True)
    )
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): IFNode(
      v_threshold=1.0, v_reset=None, detach_reset=False
      (surrogate_function): Sigmoid(alpha=1.0, spiking=True)
    )
    (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (5): IFNode(
      v_threshold=1.0, v_reset=None, detach_reset=False
      (surrogate_function): Sigmoid(alpha=1.0, spiking=True)
    )
    (6): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (7): IFNode(
      v_threshold=1.0, v_reset=None, detach_reset=False
      (surrogate_function): Sigmoid(alpha=1.0, spiking=True)
    )
    (8): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    (9): IFNode(
      v_threshold=1.0, v_reset=None, detach_reset=False
      (surrogate_function): Sigmoid(alpha=1.0, spiking=True)
    )
    (10): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (11): IFNode(
      v_threshold=1.0, v_reset=None, detach_reset=False
      (surrogate_function): Sigmoid(alpha=1.0, spiking=True)
    )
    (12): Flatten(start_dim=1, end_dim=-1)
    (13): Linear(in_features=32, out_features=10, bias=True)
    (14): IFNode(
      v_threshold=1.0, v_reset=None, detach_reset=False
      (surrogate_function): Sigmoid(alpha=1.0, spiking=True)
    )
  )
)
100%|██████████| 100/100 [00:00<00:00, 285.67it/s]
[SNN Simulating... 1.00%] Acc:0.990
100%|██████████| 100/100 [00:00<00:00, 374.17it/s]
[SNN Simulating... 2.00%] Acc:0.995
 78%|███████▊  | 78/100 [00:00<00:00, 383.98it/s][SNN Simulating... 3.00%] Acc:0.990
100%|██████████| 100/100 [00:00<00:00, 383.84it/s]
100%|██████████| 100/100 [00:00<00:00, 387.40it/s]
[SNN Simulating... 4.00%] Acc:0.993
100%|██████████| 100/100 [00:00<00:00, 380.71it/s]
[SNN Simulating... 5.00%] Acc:0.992
 77%|███████▋  | 77/100 [00:00<00:00, 386.11it/s][SNN Simulating... 6.00%] Acc:0.992
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值