理论
state_dict()
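功能：获取模型当前的状态，以有序字典（`OrderedDict`）的形式返回。这个有序字典中，key 是参数名（形如“层名.参数名”，例如 `conv1.weight`），value 是对应的参数张量。除了可学习参数外，state_dict 还会包含持久化的缓冲区（如 BatchNorm 的 `running_mean`）；没有参数的层（如下例中的 `pool`）不会出现在其中。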
代码
# coding: utf-8
import torch.nn as nn
import torch.nn.functional as F
# ----------------------------------- state_dict
class Net(nn.Module):
    """Tiny CNN used to demonstrate ``nn.Module.state_dict()``.

    Sub-modules are registered in this order, which is also the order in
    which their entries appear in ``state_dict()``:

    * ``conv1`` - 2D convolution, 3 input channels -> 1 output channel,
      3x3 kernel; contributes ``conv1.weight`` and ``conv1.bias``.
    * ``pool``  - 2x2 max pooling with stride 2; it has no parameters, so it
      contributes nothing to ``state_dict()``.
    * ``fc1``   - fully connected layer, 9 -> 2 features; contributes
      ``fc1.weight`` and ``fc1.bias``.

    ``fc1`` takes 1 * 3 * 3 = 9 features, so the feature map after pooling
    must be 1x3x3. An 8x8 (or 9x9) input image satisfies this:
    8 - 3 + 1 = 6 after ``conv1``, then 6 // 2 = 3 after ``pool``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 1, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(1 * 3 * 3, 2)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: input tensor of shape ``(N, 3, H, W)`` whose spatial size
                reduces to 3x3 after ``conv1`` + ``pool`` (e.g. 8x8).

        Returns:
            Tensor of shape ``(N, 2)``; non-negative because ReLU is applied
            to the output of ``fc1``.

        Raises:
            RuntimeError: if the flattened per-sample feature count does not
                match ``fc1``'s 9 input features.
        """
        x = self.pool(F.relu(self.conv1(x)))
        # Flatten each sample while keeping the batch dimension N.
        # ``x.view(-1, 9)`` would silently re-split the batch when the
        # input has the wrong spatial size; this form makes ``fc1`` raise.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        return x
net = Net()
# Fetch the network's current state: an ordered mapping from entry name
# (e.g. "conv1.weight") to the corresponding tensor.
net_state_dict = net.state_dict()
print('net_state_dict类型:', type(net_state_dict))
print('net_state_dict管理的参数: ', net_state_dict.keys())
print('net_state_dict:', net_state_dict)
# Print every entry's name together with its tensor shape.
for key, value in net_state_dict.items():
    print('参数名: ', key, '\t大小: ', value.shape)
结果（参数为随机初始化，每次运行得到的具体数值会不同，但键名和形状相同）
net_state_dict类型: <class 'collections.OrderedDict'>
net_state_dict管理的参数: odict_keys(['conv1.weight', 'conv1.bias', 'fc1.weight', 'fc1.bias'])
net_state_dict: OrderedDict([
('conv1.weight', tensor([[[[ 0.1847, 0.1356, 0.0460],
[-0.0917, -0.1756, -0.0554],
[-0.1254, -0.1861, -0.0111]],
[[ 0.0801, 0.0135, -0.1842],
[-0.1620, 0.0984, 0.0845],
[ 0.0077, 0.0937, -0.0286]],
[[ 0.0141, 0.0263, 0.1229],
[ 0.1565, 0.0061, 0.0611],
[-0.0304, 0.1011, 0.1769]]]])),
('conv1.bias', tensor([0.0328])),
('fc1.weight', tensor([[-0.0093, -0.0534, -0.2707, -0.2470, 0.2448, 0.2793, -0.2237, 0.1177, 0.1881],
[-0.1974, -0.2257, 0.0995, -0.2708, 0.3137, -0.0334, 0.2609, -0.0878,-0.1505]])),
('fc1.bias', tensor([0.2143, 0.0708]))])
参数名: conv1.weight 大小: torch.Size([1, 3, 3, 3])
参数名: conv1.bias 大小: torch.Size([1])
参数名: fc1.weight 大小: torch.Size([2, 9])
参数名: fc1.bias 大小: torch.Size([2])