目录
在 PyTorch 中,你可以使用如下方式获取神经网络中的某一部分参数:
1. 使用 Module.named_parameters()
函数,这将返回一个迭代器,包含网络中所有可学习参数的名称和数值。例如:
import torch
import torch.nn as nn
# 定义一个简单的神经网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 6 * 6, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
# 创建网络实例并获取参数
net = Net()
for name, param in net.named_parameters():
print(name, param.size())
输出结果如下:
conv1.weight torch.Size