The difference between named_modules and named_parameters can already be guessed from their names: one yields the modules, the other yields the parameters. Still, it helps to see it more concretely.
A textual explanation can be found here.
The network model shown below comes from https://github.com/Yanqi-Chen/Gradient-Rewiring
The only modification made here is to print out the two different kinds of information:
import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.clock_driven import functional, layer, surrogate, neuron
from torchvision import transforms


class Cifar10Net(nn.Module):
    def __init__(self, T=8, v_threshold=1.0, v_reset=0.0, tau=2.0, surrogate_function=surrogate.ATan()):
        super().__init__()
        self.train_times = 0
        self.epochs = 0
        self.max_test_accuracy = 0
        self.T = T  # number of simulation time steps
        self.static_conv = nn.Sequential(
            nn.Conv2d(3, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
        )
        self.conv = nn.Sequential(
            neuron.LIFNode(v_threshold=v_threshold, v_reset=v_reset, tau=tau, surrogate_function=surrogate_function, detach_reset=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            neuron.LIFNode(v_threshold=v_threshold, v_reset=v_reset, tau=tau, surrogate_function=surrogate_function, detach_reset=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            neuron.LIFNode(v_threshold=v_threshold, v_reset=v_reset, tau=tau, surrogate_function=surrogate_function, detach_reset=True),
            nn.MaxPool2d(2, 2),  # 16 * 16
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            neuron.LIFNode(v_threshold=v_threshold, v_reset=v_reset, tau=tau, surrogate_function=surrogate_function, detach_reset=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            neuron.LIFNode(v_threshold=v_threshold, v_reset=v_reset, tau=tau, surrogate_function=surrogate_function, detach_reset=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            neuron.LIFNode(v_threshold=v_threshold, v_reset=v_reset, tau=tau, surrogate_function=surrogate_function, detach_reset=True),
            nn.MaxPool2d(2, 2)  # 8 * 8
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            layer.Dropout(0.5),
            nn.Linear(256 * 8 * 8, 128 * 4 * 4, bias=False),
            neuron.LIFNode(v_threshold=v_threshold, v_reset=v_reset, tau=tau, surrogate_function=surrogate_function, detach_reset=True),
            nn.Linear(128 * 4 * 4, 100, bias=False),
            neuron.LIFNode(v_threshold=v_threshold, v_reset=v_reset, tau=tau, surrogate_function=surrogate_function, detach_reset=True)
        )
        # 10 output neurons per class; averaging each group of 10 gives the 10 class scores
        self.boost = nn.AvgPool1d(10, 10)

    def forward(self, x):
        x = self.static_conv(x)
        # accumulate the output spikes over T time steps
        out_spikes_counter = self.boost(self.fc(self.conv(x)).unsqueeze(1)).squeeze(1)
        for _ in range(1, self.T):
            out_spikes_counter += self.boost(self.fc(self.conv(x)).unsqueeze(1)).squeeze(1)
        return out_spikes_counter


if __name__ == "__main__":
    net = Cifar10Net()
    # print(net)
    print('named_modules:')
    for name, module in net.named_modules():
        print('name: {}, module: {}'.format(name, module))
    print('#####################################################')
    print('named_parameters:')
    for name, param in net.named_parameters():
        print('name: {}, param: {}'.format(name, param))
The output consists of two different listings: the first (named_modules) shows the detailed information of every module, including the containers themselves; the second (named_parameters) shows the weight tensors.
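To see the difference without the verbose output, a minimal sketch reusing the net built above: printing only the names shows that named_modules() also yields the containers (the root '', static_conv, conv, fc, boost), while named_parameters() yields only the learnable tensors of the leaf layers (Conv2d, BatchNorm2d, Linear); the LIFNode layers hold no parameters.

module_names = [name for name, _ in net.named_modules()]
param_names = [name for name, _ in net.named_parameters()]
print(module_names)  # ['', 'static_conv', 'static_conv.0', 'static_conv.1', 'conv', 'conv.0', ...]
print(param_names)   # ['static_conv.0.weight', 'static_conv.1.weight', 'static_conv.1.bias', 'conv.1.weight', ...]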
The module objects can be inspected; a module may also expose intermediate state that can be read out. For example, here we check whether a certain attribute exists:

spike_times = {}
for name, module in net.named_modules():
    if hasattr(module, 'monitor'):  # SpikingJelly's spiking neurons carry a monitor attribute
        spike_times[name] = 0
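If the goal is simply to count spikes per layer, a plain-PyTorch alternative is a forward hook. This is a sketch, not from the original post; it assumes each LIFNode outputs a 0/1 spike tensor:

spike_times = {}  # reset the counters

def make_hook(layer_name):
    def hook(module, inputs, output):
        # accumulate the total number of spikes this layer emitted
        spike_times[layer_name] += output.sum().item()
    return hook

for name, module in net.named_modules():
    if isinstance(module, neuron.LIFNode):
        spike_times[name] = 0
        module.register_forward_hook(make_hook(name))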
The parameters can be pulled out and handled differently, for example for optimization. Here the BatchNorm parameters are separated from the weight parameters:

BN_list = ['static_conv.1', 'conv.2', 'conv.5', 'conv.9', 'conv.12', 'conv.15']
bn_params, weight_params = [], []
w_cnt, ttl_cnt = 0, 0  # parameter counters
for name, param in net.named_parameters():
    if any(BN_name in name for BN_name in BN_list):
        bn_params += [param]
        ttl_cnt += param.numel()
    else:
        weight_params += [param]
        w_cnt += param.numel()
        ttl_cnt += param.numel()
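A common use of such a split (a sketch, not from the original repo; the hyperparameter values are illustrative): feed the two groups into one optimizer with different settings, e.g. no weight decay on the BatchNorm parameters.

optimizer = torch.optim.SGD(
    [
        {'params': weight_params, 'weight_decay': 5e-4},  # illustrative value
        {'params': bn_params, 'weight_decay': 0.0},       # BN params usually skip weight decay
    ],
    lr=0.1, momentum=0.9,  # illustrative hyperparameters
)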
