# Demo: set all parameters of a PyTorch network to 0.
from torchvision.models import resnet18
from torch.nn import init
# Build an untrained ResNet-18. Calling it with no arguments keeps the default
# random initialisation (equivalent to the old ``pretrained=False``) without
# triggering the deprecation warning newer torchvision emits for that keyword.
net = resnet18()

# Zero every learnable tensor. Iterating ``parameters()`` instead of filtering
# ``state_dict()`` keys by name also covers weights whose names contain neither
# "conv" nor "bn" (e.g. ``fc.weight`` and the ``downsample`` layers), which a
# name-based filter would leave at their random initial values.
# ``init.zeros_`` fills the tensor in place without recording the operation in
# autograd, so it is safe on parameters that require grad. Buffers (such as the
# BatchNorm running statistics) are not parameters and are left untouched.
for param in net.parameters():
    init.zeros_(param)

# Print every parameter to confirm that all of them are now zero.
for param in net.parameters():
    print(param)