TactileSGNet代码学习

TactileSGNet代码学习

TactileSGNet: A Spiking Graph Neural Network for Event-based
Tactile Object Recognition

Fuqiang Gu, Weicong Sng, Tasbolat Taunyazov, and Harold Soh
Dept. of Computer Science, School of Computing,
National University of Singapore

Github:https://github.com/clear-nus/TactileSGNet

论文主要内容

  • 研究领域:基于触觉的家用物体识别(36种多分类)

  • 所用技术:SNN+GNN

  • 方法概览:

    image-20201015101219663

    输入:机器人有39个触点,对这39个触点进行监测250个时间片,形成250个图输入

    处理:SNN封装的GNN+MLP

    输出:label

代码主要内容

  1. 可复现性,出了seed一致之外

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    

    作者还用了cudnn的API来固定cuda的seed和卷积底层算法。(默认会自动选择合适的底层卷积算法)

  2. 模型架构

    对每个时间片的激活点,构成一个图

    对每个图进行SNN封装后的卷积和MLP

    class TactileSGNet(nn.Module):
        def forward(self, input, training = True):
            ………………
            for step in range(time_window): 
                x = inputs[step].squeeze()
                x = x.to(self.device)
                graph_data = self.graph(x)
                x = graph_data.x.to(self.device) 
                edge_idxs = graph_data.edge_index.to(self.device)
    
                c1_mem, c1_spike = mem_update_conv(self.conv1, x, edge_idxs, c1_mem, c1_spike)
                x = c1_spike
                x = x.view(-1)
    
                h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike)
                h1_sumspike += h1_spike
                h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem,h2_spike)
                h2_sumspike += h2_spike
                h3_mem, h3_spike = mem_update(self.fc3, h2_spike, h3_mem, h3_spike)
                h3_sumspike += h3_spike
    
            outputs = h3_sumspike / time_window
            return outputs
    

    基于LIF神经元的激活和传播

    Direct Training for Spiking Neural Networks: Faster, Larger, Better.,AAAI,2019,paper,code

    Wu, Yujie, Lei Deng, Guoqi Li, Jun Zhu, and Luping Shi.

    class ActFun(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            return input.gt(thresh).float()
    
        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors
            grad_input = grad_output.clone()
            temp = abs(input - thresh) < lens
            return grad_input * temp.float()
    
    act_fun = ActFun.apply
    # membrane potential update, for GCN
    def mem_update_conv(ops, x, edge_idxs, mem, spike):
        mem = mem * decay * (1. - spike) + ops(x, edge_idxs)
        spike = act_fun(mem) # act_fun : approximation firing function
        return mem, spike
    
    def mem_update(ops, x, mem, spike):
        mem = mem * decay * (1. - spike) + ops(x)
        spike = act_fun(mem)
        return mem, spike
    

    记忆值*衰减率* +当下脉冲值

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值