BindsNet使用实践

BindsNet是构建在PyTorch深度学习平台之上。 它用于模拟尖峰神经网络(SNNs),面向机器学习和强化学习。

本文主要使用BindsNet来生成MNIST图片的脉冲信息

首先,我们来看看我们将要使用到的BindsNet的重要组件:

Networks(用于创建一个网络)

Nodes(节点,就是一个个的神经元)

Connection(连接,用于节点之间进行连接)

Monitor(监视器,用于获取脉冲过程当中产生的脉冲数据)

首先加载MNIST数据

#加载MNIST数据.
dataset = MNIST(
    NullEncoder(),
    None,
    root=os.path.join("data", "MNIST"),
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
    ),
)

#将数据格式化为数据集
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

其次创建NetWorks以及Nodes

# Create the network.
network = Network()

# Create and add input, output layers.
source_layer = Input(n=n1*n1)
target_layer = LIFNodes(n=n2*n2,refrac=2)
con_layer = LIFNodes(n=n3*n3,refrac=2)
output_layer = LIFNodes(n=n4*n4,refrac=2)

#定义分类层
class_layer = LIFNodes(n=nclass,refrac=2)

network.add_layer(
    layer=source_layer, name="A"
)
network.add_layer(
    layer=target_layer, name="B"
)
network.add_layer(
    layer=con_layer, name="C"
)
network.add_layer(
    layer=output_layer, name="D"
)

network.add_layer(
    layer=class_layer, name="E"
)

进一步创建连接Connection

#输入层到卷积层
w1 = conv1(n1)
# 创建连接
forward_connection = Connection(
    source=source_layer,
    target=target_layer,
    w=w1,  
)
#网络增加连接
network.add_connection(
    connection=forward_connection, source="A", target="B"
)

进一步创建监视器Monitor

target_monitor = Monitor(
    obj=class_layer,
    state_vars=("s", "v"),
    time=time*3,  
)

network.add_monitor(monitor=source_monitor, name="C")

最后run起来

for (i, datum) in enumerate(dataloader):
    image = datum["encoded_image"]
    label = datum["label"]
    input_data =  torch.tensor(np.tile(datum["encoded_image"].view(784),(time,1))).float()
    key = label[0]
    print(key)
    inputs = {"A": input_data}
    network.run(inputs=inputs, time=time)
    spikes = {
        "C": source_monitor.get("s"), 
        "D": target_monitor.get("s")
    }
    voltages = {"D": target_monitor.get("v")}
    #更新输出层到分类层的权重,以发出脉冲
    for k in source_monitor.get("s").numpy():
        unit = np.argwhere(k[0]==True).flatten()
        if len(unit>0):
            refreshWeight('D','E', unit)
        step = step +1
    network.run(inputs=inputs, time=time)
    np.savetxt('spike_jjjj.txt', network.connections['D', 'E'].w, fmt='%3.2f', delimiter=',')
    #更新分类层自连接,修改
    for k in target_monitor.get("s").numpy():
        unit = np.argwhere(k[0]==True).flatten()

最终我们根据自己计划的模型,例如节点之间的联系方式,节点数量,连接方式,得到脉冲信息

 

  • 0
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值