import torch
import torch.nn as nn
class mymodule(nn.Module):
def __init__(self):
super(mymodule,self).__init__()
self.linear=nn.Linear(2,3)
self.relu=nn.ReLU()
def forward(self,x):
x=self.linear(x)
x=self.relu(x)
return x
model=mymodule()
print("模型参数:",(model.parameters()))
for param in model.parameters():
print("参数类型:",type(param),"参数大小:",param.size())
查看网络模型参数
print("模型参数:",list((model.parameters())))
torch.nn.Module中的.parameters()方法使用
model.parameters()方法返回的是一个生成器generator,每一个元素是从开头到结尾的参数,parameters没有对应的key名称,是一个由纯参数组成的generator,而state_dict是一个字典,包含了一个key。
具体可以参考:
https://blog.csdn.net/qq_27825451/article/details/95888267