在PyTorch中,可以使用parameters函数来获取模型中的所有可学习参数。以下是一个示例:
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
model = MyModel()
params = list(model.parameters())
在这个示例中,我们首先定义了一个包含两个线性层的神经网络,然后通过list(model.parameters())获取了模型中的所有可学习参数。这些参数存储在一个Python列表中,可以用于进行优化器的初始化和模型的保存和加载。