The latest version of SpikingJelly at the time (I was using the 20210330 release) has an ONNX package import error that is hard to resolve. After debugging the 20201221 release, I found that it works, although it also needs a few fixes of its own.
1. First, train an ANN model. At this stage main() should be called with no argument, i.e.
if __name__ == '__main__':
main()
# main('./log-cnn_mnist1622169436.7251258')
2. Once the ANN model has been trained, a folder is generated automatically, e.g.
./log-cnn_mnist1622169436.7251258
This folder stores a large number of ANN and SNN model files.
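The folder name is the model name with a Unix timestamp appended. A minimal sketch of how such a directory name can be produced; the helper name make_log_dir and the exact string construction are my assumption, not SpikingJelly's actual code:

```python
import os
import time

def make_log_dir(model_name: str, root: str = '.') -> str:
    # Build a timestamped log directory name like ./log-cnn_mnist1622169436.7251258
    # (hypothetical helper; the real example script may construct the name differently).
    log_dir = os.path.join(root, f'log-{model_name}{time.time()}')
    os.makedirs(log_dir, exist_ok=True)  # create it so checkpoints can be saved inside
    return log_dir

log_dir = make_log_dir('cnn_mnist')
print(log_dir)
```

Passing this path back into main() (the commented-out call above) lets the script reuse the saved models instead of retraining.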
3. Then run the ANN-to-SNN conversion:
utils.pytorch_ann2snn(model_name=model_name,
norm_tensor=norm_tensor,
test_data_loader=test_data_loader,
device=device,
T=T,
log_dir=log_dir,
config=config
)
There are several pitfalls here:
3.1
In the original code, the second parameter was not named norm_tensor; it was something like test_tensor. This one is easy to spot: Ctrl-click on pytorch_ann2snn to jump to its definition, and the mismatch shows up quickly.
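Besides Ctrl-clicking into the definition, printing the function's signature is a quick way to catch this kind of keyword-argument mismatch. A small illustration using a stand-in function whose parameters mirror the call above (the stand-in body is mine, not the library's):

```python
import inspect

# Stand-in with the same keyword parameters as the (fixed) pytorch_ann2snn call above;
# the body is a placeholder, only the signature matters here.
def pytorch_ann2snn(model_name, norm_tensor, test_data_loader, device, T, log_dir, config):
    pass

params = list(inspect.signature(pytorch_ann2snn).parameters)
print(params)
# Calling it with test_tensor=... would raise a TypeError, because the
# parameter is actually named norm_tensor:
assert 'norm_tensor' in params and 'test_tensor' not in params
```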
3.2
utils.py contains a large number of functions. One of them is pytorch_ann2snn, and one line inside it calls val_ann. The correct code is:
ann_acc = val_ann(net=parsed_ann, loss_function=nn.CrossEntropyLoss(), device=device, data_loader=test_data_loader)
In the original code, the loss_function argument seemed to be missing.
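For reference, a validation helper with this signature typically just runs the network over the loader in eval mode and reports top-1 accuracy. A minimal sketch; the real val_ann in the 20201221 utils.py may differ in details such as logging or also returning the loss:

```python
import torch
import torch.nn as nn

def val_ann(net, loss_function, device, data_loader):
    # Evaluate an ANN: return top-1 accuracy over the data loader.
    net.to(device).eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            out = net(x)
            loss = loss_function(out, y)  # the argument the original call was missing
            correct += (out.argmax(dim=1) == y).sum().item()
            total += y.numel()
    return correct / total

# Tiny usage example: identity "network" on one-hot inputs classifies perfectly.
acc = val_ann(nn.Identity(), nn.CrossEntropyLoss(), 'cpu',
              [(torch.eye(4), torch.tensor([0, 1, 2, 3]))])
print(acc)  # → 1.0
```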
3.3
There was another val_ann problem, but I have forgotten the details...
In any case, the conversion succeeded in the end.
The console output is as follows:
All the temp files are saved to ./log-cnn_mnist1622169436.7251258
ann2snn config:
{'simulation': {'reset_to_zero': False,
                'encoder': {'possion': False},
                'avg_pool': {'has_neuron': True},
                'max_pool': {'if_spatial_avg': False, 'if_wta': False, 'momentum': None}},
 'parser': {'robust_norm': True}}
Directly load model cnn_mnist.pkl
Using 120 pictures as norm set
Load best model for Model:cnn_mnist...
ANN Validating Accuracy:0.986
Save model to: ./log-cnn_mnist1622169436.7251258\parsed_cnn_mnist.pkl
Using robust normalization...
normalize with bias...
Print Parsed ANN model Structure:
Pytorch_Parser(
(network): Sequential(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
(2): AvgPool2d(kernel_size=2, stride=2, padding=0)
(3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(4): ReLU()
(5): AvgPool2d(kernel_size=2, stride=2, padding=0)
(6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(7): ReLU()
(8): AvgPool2d(kernel_size=2, stride=2, padding=0)
(9): Flatten(start_dim=1, end_dim=-1)
(10): Linear(in_features=32, out_features=10, bias=True)
(11): ReLU()
)
)
Save model to: ./log-cnn_mnist1622169436.7251258\normalized_cnn_mnist.pkl
Print Simulated SNN model Structure:
PyTorch_Converter(
(network): Sequential(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(1): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False
(surrogate_function): Sigmoid(alpha=1.0, spiking=True)
)
(2): AvgPool2d(kernel_size=2, stride=2, padding=0)
(3): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False
(surrogate_function): Sigmoid(alpha=1.0, spiking=True)
)
(4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(5): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False
(surrogate_function): Sigmoid(alpha=1.0, spiking=True)
)
(6): AvgPool2d(kernel_size=2, stride=2, padding=0)
(7): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False
(surrogate_function): Sigmoid(alpha=1.0, spiking=True)
)
(8): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
(9): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False
(surrogate_function): Sigmoid(alpha=1.0, spiking=True)
)
(10): AvgPool2d(kernel_size=2, stride=2, padding=0)
(11): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False
(surrogate_function): Sigmoid(alpha=1.0, spiking=True)
)
(12): Flatten(start_dim=1, end_dim=-1)
(13): Linear(in_features=32, out_features=10, bias=True)
(14): IFNode(
v_threshold=1.0, v_reset=None, detach_reset=False
(surrogate_function): Sigmoid(alpha=1.0, spiking=True)
)
)
)
100%|██████████| 100/100 [00:00<00:00, 285.67it/s]
[SNN Simulating... 1.00%] Acc:0.990
100%|██████████| 100/100 [00:00<00:00, 374.17it/s]
[SNN Simulating... 2.00%] Acc:0.995
78%|███████▊ | 78/100 [00:00<00:00, 383.98it/s][SNN Simulating... 3.00%] Acc:0.990
100%|██████████| 100/100 [00:00<00:00, 383.84it/s]
100%|██████████| 100/100 [00:00<00:00, 387.40it/s]
[SNN Simulating... 4.00%] Acc:0.993
100%|██████████| 100/100 [00:00<00:00, 380.71it/s]
[SNN Simulating... 5.00%] Acc:0.992
77%|███████▋ | 77/100 [00:00<00:00, 386.11it/s][SNN Simulating... 6.00%] Acc:0.992
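The accuracy percentages above come from rate coding: the converted SNN is run for T timesteps, output spikes are accumulated, and the predicted class is the neuron that fired most often. The output above also shows v_reset=None, i.e. a soft (subtract-threshold) reset. A minimal sketch of that decision rule with a toy IF neuron layer; this is my own illustration, independent of SpikingJelly's internals:

```python
import torch

def if_step(v, x, v_threshold=1.0):
    # One integrate-and-fire step with soft reset: subtract the threshold on spike,
    # matching v_reset=None in the printed SNN structure above.
    v = v + x
    spike = (v >= v_threshold).float()
    v = v - spike * v_threshold
    return v, spike

def snn_predict(x, T=100, v_threshold=1.0):
    # Feed a constant input current for T steps and classify by spike counts.
    v = torch.zeros_like(x)
    counts = torch.zeros_like(x)
    for _ in range(T):
        v, s = if_step(v, x, v_threshold)
        counts += s
    return counts.argmax(dim=1), counts

logits = torch.tensor([[0.1, 0.9, 0.3]])  # toy "input currents" for 3 classes
pred, counts = snn_predict(logits, T=100)
print(pred.item(), counts)  # the neuron with the largest input fires most often
```

This is also why accuracy climbs as the simulated percentage of T grows: more timesteps means more spikes, so the firing rates approximate the ANN activations more closely.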