BindsNet是构建在PyTorch深度学习平台之上。 它用于模拟尖峰神经网络(SNNs),面向机器学习和强化学习。
本文主要使用BindsNet来生成MNIST图片的脉冲信息
首先,我们来看看我们将要使用到的BindsNet的重要组件:
Networks(用于创建一个网络)
Nodes(节点,就是一个个的神经元)
Connection(连接,用于节点之间进行连接)
Monitor(监视器,用于获取脉冲过程当中产生的脉冲数据)
首先加载MNIST数据
#加载MNIST数据.
dataset = MNIST(
NullEncoder(),
None,
root=os.path.join("data", "MNIST"),
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
),
)
#将数据格式化为数据集
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
其次创建NetWorks以及Nodes
# Create the network.
network = Network()
# Create and add input, output layers.
source_layer = Input(n=n1*n1)
target_layer = LIFNodes(n=n2*n2,refrac=2)
con_layer = LIFNodes(n=n3*n3,refrac=2)
output_layer = LIFNodes(n=n4*n4,refrac=2)
#定义分类层
class_layer = LIFNodes(n=nclass,refrac=2)
network.add_layer(
layer=source_layer, name="A"
)
network.add_layer(
layer=target_layer, name="B"
)
network.add_layer(
layer=con_layer, name="C"
)
network.add_layer(
layer=output_layer, name="D"
)
network.add_layer(
layer=class_layer, name="E"
)
进一步创建连接Connection
#输入层到卷积层
w1 = conv1(n1)
# 创建连接
forward_connection = Connection(
source=source_layer,
target=target_layer,
w=w1,
)
#网络增加连接
network.add_connection(
connection=forward_connection, source="A", target="B"
)
进一步创建监视器Monitor
target_monitor = Monitor(
obj=class_layer,
state_vars=("s", "v"),
time=time*3,
)
network.add_monitor(monitor=source_monitor, name="C")
最后run起来
for (i, datum) in enumerate(dataloader):
image = datum["encoded_image"]
label = datum["label"]
input_data = torch.tensor(np.tile(datum["encoded_image"].view(784),(time,1))).float()
key = label[0]
print(key)
inputs = {"A": input_data}
network.run(inputs=inputs, time=time)
spikes = {
"C": source_monitor.get("s"),
"D": target_monitor.get("s")
}
voltages = {"D": target_monitor.get("v")}
#更新输出层到分类层的权重,以发出脉冲
for k in source_monitor.get("s").numpy():
unit = np.argwhere(k[0]==True).flatten()
if len(unit>0):
refreshWeight('D','E', unit)
step = step +1
network.run(inputs=inputs, time=time)
np.savetxt('spike_jjjj.txt', network.connections['D', 'E'].w, fmt='%3.2f', delimiter=',')
#更新分类层自连接,修改
for k in target_monitor.get("s").numpy():
unit = np.argwhere(k[0]==True).flatten()
最终我们根据自己计划的模型,例如节点之间的联系方式,节点数量,连接方式,得到脉冲信息