TactileSGNet代码学习
TactileSGNet: A Spiking Graph Neural Network for Event-based
Tactile Object RecognitionFuqiang Gu, Weicong Sng, Tasbolat Taunyazov, and Harold Soh
Dept. of Computer Science, School of Computing,
National University of SingaporeGithub:https://github.com/clear-nus/TactileSGNet
论文主要内容
-
研究领域:基于触觉的家用物体识别(36种多分类)
-
所用技术:SNN+GNN
-
方法概览:
输入:机器人有39个触点,对这39个触点进行监测250个时间片,形成250个图输入
处理:SNN封装的GNN+MLP
输出:label
代码主要内容
-
可复现性,出了seed一致之外
torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False
作者还用了cudnn的API来固定cuda的seed和卷积底层算法。(默认会自动选择合适的底层卷积算法)
-
模型架构
对每个时间片的激活点,构成一个图
对每个图进行SNN封装后的卷积和MLP
class TactileSGNet(nn.Module): def forward(self, input, training = True): ……………… for step in range(time_window): x = inputs[step].squeeze() x = x.to(self.device) graph_data = self.graph(x) x = graph_data.x.to(self.device) edge_idxs = graph_data.edge_index.to(self.device) c1_mem, c1_spike = mem_update_conv(self.conv1, x, edge_idxs, c1_mem, c1_spike) x = c1_spike x = x.view(-1) h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike) h1_sumspike += h1_spike h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem,h2_spike) h2_sumspike += h2_spike h3_mem, h3_spike = mem_update(self.fc3, h2_spike, h3_mem, h3_spike) h3_sumspike += h3_spike outputs = h3_sumspike / time_window return outputs
基于LIF神经元的激活和传播
Direct Training for Spiking Neural Networks: Faster, Larger, Better.,AAAI,2019,paper,code
Wu, Yujie, Lei Deng, Guoqi Li, Jun Zhu, and Luping Shi.
class ActFun(torch.autograd.Function): @staticmethod def forward(ctx, input): ctx.save_for_backward(input) return input.gt(thresh).float() @staticmethod def backward(ctx, grad_output): input, = ctx.saved_tensors grad_input = grad_output.clone() temp = abs(input - thresh) < lens return grad_input * temp.float() act_fun = ActFun.apply # membrane potential update, for GCN def mem_update_conv(ops, x, edge_idxs, mem, spike): mem = mem * decay * (1. - spike) + ops(x, edge_idxs) spike = act_fun(mem) # act_fun : approximation firing function return mem, spike def mem_update(ops, x, mem, spike): mem = mem * decay * (1. - spike) + ops(x) spike = act_fun(mem) return mem, spike
记忆值*衰减率* +当下脉冲值