继承 Module 类来构造模型
定义 MLP 类,重载 Module 类的 __init__ 和 forward 函数,分别用于创建模型参数和定义前向计算;无需定义反向传播函数,
系统会通过自动求梯度生成 backward 函数。
from collections import OrderedDict

import torch
import torch.nn as nn
class MLP(nn.Module):
    """A simple multilayer perceptron: 784 -> 256 (ReLU) -> 10.

    Only __init__ and forward are defined; the backward function is
    generated automatically by autograd.
    """

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(784, 256)  # hidden layer
        self.act = nn.ReLU()               # activation
        self.output = nn.Linear(256, 10)   # output layer

    def forward(self, X):
        """Forward pass for X of shape (batch, 784); returns (batch, 10)."""
        a = self.act(self.hidden(X))
        return self.output(a)
# Instantiate the network and run a forward pass on a random batch
X=torch.rand(2,784)
net=MLP()
print(net)
net(X)
MLP(
(hidden): Linear(in_features=784, out_features=256, bias=True)
(act): ReLU()
(output): Linear(in_features=256, out_features=10, bias=True)
)
tensor([[-0.1568, -0.1277, 0.1709, 0.1437, 0.0183, 0.2180, -0.1757, -0.2499,
0.0646, 0.0176],
[-0.0939, -0.1590, 0.0674, 0.2006, -0.0693, 0.1233, -0.1146, -0.2737,
-0.0184, -0.0989]], grad_fn=<AddmmBackward>)
Module 子类
(1)Sequential 便于简单的前向计算串联各层,具体实现细节:
class MySequential(nn.Module):
    """A minimal re-implementation of nn.Sequential.

    Accepts either a single OrderedDict mapping names to modules, or a
    plain sequence of modules (auto-named by their index). forward
    chains the registered submodules in insertion order.
    """

    def __init__(self, *args):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            # Named modules: preserve the caller-supplied keys.
            # BUG FIX: the original accessed `args[0].items` without
            # calling it, so iteration failed and no module was added.
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            # Anonymous modules: register under their positional index.
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, input):
        # Feed the input through every registered module in order.
        for module in self._modules.values():
            input = module(input)
        return input
# Build the same MLP topology with MySequential and run a forward pass
net=MySequential(nn.Linear(784,256),nn.ReLU(),nn.Linear(256,10))
print(net)
net(X)
MySequential(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=10, bias=True)
)
tensor([[-0.0798, 0.1060, -0.0161, -0.0773, -0.0844, 0.2117, 0.2107, -0.0883,
0.1431, -0.0817],
[ 0.0164, 0.0234, -0.0436, -0.0638, -0.0055, 0.2168, 0.0212, -0.2043,
0.0989, -0.0613]], grad_fn=<AddmmBackward>)
访问模型参数
print(type(net.named_parameters())) #查看类型
<class 'generator'>
# Print the name and shape of every learnable parameter.
# (The loop body was not indented in the original — IndentationError.)
for name, param in net.named_parameters():
    print(name, param.size())
0.weight torch.Size([256, 784])
0.bias torch.Size([256])
2.weight torch.Size([10, 256])
2.bias torch.Size([10])
初始化模型参数
PyTorch 中的 nn.Module 模块已经采用较为合理的初始化策略,没有特殊情况一般不需要再手动初始化。PyTorch 提供了 init 模块,内置多种预设的初始化方法。
# Example: re-initialize the parameters of `net` with torch.nn.init.
# (The loop/if bodies were not indented in the original.)
import torch.nn.init as init

for name, param in net.named_parameters():
    if 'weight' in name:
        # Weights: small Gaussian noise.
        init.normal_(param, mean=0, std=0.01)
        print(name, param.data)
    if 'bias' in name:
        # Biases: constant zero.
        init.constant_(param, val=0)
        print(name, param.data)
0.weight tensor([[-7.3670e-03, -4.5059e-03, 2.5534e-03, ..., -1.4920e-02,
2.9664e-02, 9.6315e-03],
[ 3.3703e-03, -1.7920e-02, -1.4533e-02, ..., -7.4964e-04,
-1.1880e-02, 6.7410e-03],
[-1.5851e-02, -7.5730e-04, -1.4035e-03, ..., -1.1753e-02,
-8.3762e-03, 6.4255e-03],
...,
[-8.0277e-04, 2.1240e-02, 6.7807e-03, ..., 6.3259e-06,
-1.1621e-02, 4.4899e-03],
[-4.8244e-03, -7.8703e-03, 6.5461e-03, ..., -3.9733e-03,
-8.0912e-04, -8.6507e-03],
[-1.6352e-02, 2.9723e-03, 2.4309e-03, ..., -1.0128e-02,
3.5137e-03, -1.3541e-02]])
0.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
2.weight tensor([[ 5.3187e-04, 1.5396e-02, -9.6198e-03, ..., 1.9582e-02,
1.2500e-02, 8.8209e-05],
[-7.1357e-03, -3.1403e-03, 5.3357e-03, ..., 6.8758e-03,
-6.4421e-03, 5.9430e-03],
[ 2.2387e-02, 5.2646e-03, -3.2153e-03, ..., 1.7143e-02,
1.2387e-02, -2.1117e-02],
...,
[-6.0646e-03, -7.9573e-03, -2.3083e-03, ..., -2.7555e-03,
5.7004e-04, -5.2114e-03],
[-1.9535e-02, -9.3469e-03, 2.1729e-02, ..., 2.7975e-02,
7.1541e-04, 2.3038e-03],
[-1.9913e-02, -1.9769e-03, 1.0276e-02, ..., 7.5271e-03,
-1.1302e-02, 5.0707e-03]])
2.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
自定义初始化方法
我们先看torch.nn.init.normal_的具体实现
def normal_(tensor, mean=0, std=1):
    """In-place normal initialization (simplified torch.nn.init.normal_).

    Fills `tensor` with samples from N(mean, std**2) inside a no_grad
    block so the operation is not tracked by autograd, and returns the
    same tensor.
    """
    with torch.no_grad():
        return tensor.normal_(mean, std)
# Custom init: with probability 1/2 a weight becomes 0; otherwise it is
# uniform on [-10, -5] or [5, 10].
def init_weight_(tensor):
    """In-place custom initialization of `tensor` (see comment above)."""
    with torch.no_grad():
        tensor.uniform_(-10, 10)
        # Zero out entries with |w| < 5; keep the rest unchanged.
        tensor *= (tensor.abs() >= 5).float()
# Apply the custom initializer to every weight tensor in `net`.
# (The loop/if bodies were not indented in the original.)
for name, param in net.named_parameters():
    if 'weight' in name:
        init_weight_(param)
        print(name, param.data)
0.weight tensor([[-0.0000, 0.0000, -0.0000, ..., -9.5116, -0.0000, 0.0000],
[-5.8680, 0.0000, 0.0000, ..., 0.0000, -0.0000, 9.0683],
[ 9.6953, -0.0000, 0.0000, ..., -5.2760, -0.0000, -8.7688],
...,
[ 0.0000, 6.1618, -0.0000, ..., -8.2309, -5.0961, 6.9231],
[ 7.7732, -0.0000, -7.7075, ..., 0.0000, -5.9438, -9.6617],
[ 8.5203, -0.0000, 6.8366, ..., -0.0000, -6.4314, -7.1154]])
2.weight tensor([[-0.0000, -6.0745, 6.4782, ..., 9.5190, -0.0000, 0.0000],
[ 0.0000, -7.1742, -0.0000, ..., -0.0000, 7.9957, -0.0000],
[-8.2802, -0.0000, 6.7018, ..., -0.0000, 0.0000, -0.0000],
...,
[-0.0000, 0.0000, -8.1016, ..., -0.0000, 7.2197, -8.4890],
[ 0.0000, -9.3670, -0.0000, ..., 6.4142, 8.5989, 9.9789],
[ 0.0000, 0.0000, 9.7598, ..., -0.0000, 6.3064, -5.4405]])
共享参数
Module类的forward函数里多次调用同一个层
# Example: the same Linear instance passed twice — layers 0 and 1 share parameters
linear=nn.Linear(1,1,bias=False)
net=nn.Sequential(linear,linear)
print(net)
Sequential(
(0): Linear(in_features=1, out_features=1, bias=False)
(1): Linear(in_features=1, out_features=1, bias=False)
)
可见第一层和第二层共享参数
读取和存储
(1)读写 Tensor
x=torch.ones(3)
torch.save(x,'x.pt')# save a single tensor
x1=torch.load('x.pt')# load it back
print(x1)
y=torch.zeros(4)
torch.save([x,y],'xy.pt')
xy_list=torch.load('xy.pt')# save a list of tensors and read it back into memory
print(xy_list)
torch.save({'x':x,'y':y},'xy_dict.pt')
xy=torch.load('xy_dict.pt')
print(xy)
tensor([1., 1., 1.])
[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]
{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}
小结:torch.save 可以保存各种对象,包括模型、张量和字典;torch.load 使用 pickle 的 unpickle 工具,将 pickle 对象文件反序列化到内存中。
(2) 读写模型
方式1:state_dict 方式(推荐);这是一个从参数名映射到参数 Tensor 的字典对象,只有具有可学习参数的层才有条目,优化器中也有一个对应的 state_dict。
# Example using the MLP built earlier: state_dict maps parameter names to tensors
net=MLP()
net.state_dict()
OrderedDict([('hidden.weight',
tensor([[-0.0136, -0.0110, 0.0297, ..., 0.0014, -0.0208, -0.0321],
[-0.0057, -0.0211, 0.0135, ..., 0.0199, -0.0014, 0.0096],
[ 0.0280, 0.0117, -0.0006, ..., 0.0331, -0.0080, -0.0237],
...,
[-0.0246, -0.0046, -0.0349, ..., 0.0085, 0.0017, 0.0251],
[-0.0058, -0.0061, -0.0205, ..., -0.0287, 0.0247, 0.0005],
[-0.0143, 0.0321, 0.0212, ..., -0.0084, 0.0244, 0.0306]])),
('hidden.bias',
tensor([ 3.1305e-02, -6.9119e-03, 2.0941e-02, -2.2305e-02, 2.5125e-02,
-1.3433e-02, 3.1011e-02, -3.1288e-03, -3.4262e-02, 4.4943e-03,
3.1427e-02, 1.9461e-02, -3.3079e-02, 2.2021e-02, -1.4982e-02,
-1.7399e-02, 2.1805e-02, 3.1871e-02, 1.2301e-02, 3.0461e-03,
3.5593e-03, -8.0371e-03, -2.8687e-02, -2.5229e-02, 9.3432e-03,
-2.4308e-02, -1.0408e-02, -2.7482e-02, 8.1858e-03, -2.6377e-02,
-7.5215e-03, -8.7385e-03, 1.5541e-02, 3.3680e-02, 2.5204e-02,
1.5183e-02, 2.7199e-02, -1.4679e-02, -3.2512e-02, -3.2054e-02,
-2.8799e-02, -1.8439e-02, -1.1006e-02, 3.0108e-02, 2.2955e-02,
-5.9173e-03, -1.4669e-02, 4.8523e-03, 2.0565e-02, -2.2397e-02,
2.7658e-02, 6.3019e-03, -5.1514e-03, -7.8413e-03, -1.8985e-02,
-7.1641e-04, -5.4482e-03, -1.0629e-02, -1.4976e-02, 2.2032e-02,
-2.0785e-02, 2.8934e-02, 6.6899e-03, 3.4973e-02, 1.8431e-02,
-1.9760e-02, 2.1899e-02, -9.4859e-03, 3.1312e-03, -1.2290e-02,
-2.8553e-02, -1.7020e-02, -1.7534e-02, 3.0155e-02, -5.1623e-03,
2.6113e-03, 1.1324e-02, 2.6100e-02, -1.2954e-02, -1.0702e-02,
-6.4077e-03, -3.4773e-02, -4.5309e-03, -2.4982e-02, -3.3974e-02,
-3.0895e-03, 2.8445e-02, 1.6657e-02, -3.3164e-02, -9.9130e-03,
-3.0933e-02, -1.5069e-02, 3.0788e-02, -3.0066e-02, 1.8212e-02,
2.6435e-02, -2.5778e-02, 1.2167e-02, 2.1587e-02, -1.9604e-02,
1.0127e-02, 1.1425e-02, -1.1706e-02, -2.5564e-02, -3.4525e-02,
-2.4434e-02, 2.2406e-02, -3.1730e-02, 3.1080e-02, 1.3585e-03,
2.0328e-02, -2.7836e-02, -9.0502e-03, 2.8293e-02, 5.3455e-04,
-2.6237e-02, 2.2733e-02, -2.7324e-02, 2.7153e-02, 3.3261e-02,
3.2525e-02, 2.3496e-02, -1.2033e-02, -9.2284e-03, -2.3265e-02,
-6.9208e-04, 2.9772e-02, 2.4725e-02, -1.0888e-02, -5.4415e-03,
1.3602e-02, -5.0255e-03, -8.1237e-03, 1.8826e-02, 1.2347e-02,
-1.7835e-02, 5.7540e-03, -8.7569e-03, 3.5205e-02, -3.3670e-02,
-3.9439e-03, -2.2330e-02, -7.5736e-03, -5.0021e-03, 2.8512e-02,
2.9699e-03, -2.0811e-02, 3.0332e-02, 3.2435e-03, -1.1954e-02,
-2.4728e-02, 1.0801e-02, -1.6735e-02, 1.7002e-02, 1.7940e-02,
3.8781e-03, -2.1850e-02, -2.1234e-02, 2.0817e-02, 9.5619e-03,
2.4271e-02, 2.5589e-05, 1.7797e-02, -1.2524e-02, 6.4133e-03,
-6.3742e-03, -6.8082e-03, 4.8706e-03, 1.0436e-02, -3.4496e-02,
-2.0085e-02, 3.0331e-02, 3.4712e-02, -1.1799e-02, 7.4912e-03,
-2.0683e-02, -4.9508e-03, -9.4079e-03, 1.9821e-02, 1.2626e-02,
3.3640e-02, -7.0766e-03, -1.6350e-02, 2.7539e-02, -3.5491e-03,
-3.2590e-02, -1.9969e-02, 3.1202e-02, -3.4139e-02, 2.4393e-02,
1.4785e-02, -6.9166e-03, 3.1012e-02, 2.2647e-02, 2.6756e-02,
6.0994e-03, 2.0769e-02, -8.6712e-03, -2.8026e-02, -5.1754e-03,
-5.6501e-03, 2.0725e-03, 4.0739e-03, -9.4127e-03, 2.4582e-02,
1.5492e-02, -1.4416e-02, -2.5642e-02, -8.5237e-03, -9.0769e-04,
1.7156e-02, 2.5980e-02, -6.8641e-04, 3.5932e-03, 1.2633e-02,
-2.8234e-02, -2.8589e-02, -4.4682e-03, -3.0578e-02, 1.5988e-02,
-9.0113e-03, 7.1784e-03, 1.2726e-02, -2.0503e-02, 2.5928e-02,
8.7758e-03, -3.5033e-02, 1.8163e-02, -2.9047e-02, -1.8063e-02,
2.5760e-02, 6.9178e-03, -2.3813e-03, -1.6222e-02, -1.9902e-02,
6.9884e-03, 2.9882e-02, 2.0432e-02, 3.2332e-02, -2.1999e-02,
1.0688e-02, -2.9595e-02, 7.1982e-03, -2.0782e-02, -6.7366e-03,
9.8019e-03, -1.1958e-02, -2.6838e-02, -2.5685e-02, -2.0283e-03,
-1.0448e-02, -2.2837e-02, -2.5476e-02, -8.7900e-03, 1.9235e-02,
-1.1392e-02])),
('output.weight',
tensor([[ 3.8825e-02, 1.8576e-02, -4.0230e-02, ..., 2.9058e-02,
3.0307e-02, 3.8201e-02],
[ 4.5807e-02, -2.3931e-02, 2.2467e-02, ..., 1.6682e-02,
3.3030e-02, -5.9647e-02],
[-3.0376e-05, 5.0309e-03, 3.1607e-02, ..., 4.8286e-02,
4.3415e-02, -2.5420e-02],
...,
[ 1.2002e-02, -4.3146e-02, -4.4974e-03, ..., -2.2281e-02,
-4.0773e-02, -2.6248e-04],
[-5.5008e-02, -3.8148e-02, -3.4065e-02, ..., 3.3826e-02,
-2.6733e-02, 1.8512e-02],
[ 9.4306e-03, -3.9237e-02, -5.6302e-02, ..., -5.6227e-02,
-3.6273e-02, 2.9709e-02]])),
('output.bias',
tensor([ 0.0620, -0.0539, -0.0383, -0.0328, -0.0176, 0.0069, 0.0100, 0.0595,
-0.0195, 0.0577]))])
# Optimizers also expose a state_dict (hyperparameters + internal state)
optimizer=torch.optim.SGD(net.parameters(),lr=0.001)
optimizer.state_dict()
{'param_groups': [{'dampening': 0,
'lr': 0.001,
'momentum': 0,
'nesterov': False,
'params': [499798207584, 499798206792, 499798380120, 499798379112],
'weight_decay': 0}],
'state': {}}
**保存和加载 state_dict(推荐方式)**
X = torch.rand(2, 784)
Y = net(X)
# NOTE(review): hard-coded absolute path — adjust for your machine.
PATH = "C:/Users/mingming/net.pt"
torch.save(net.state_dict(), PATH)      # save only the parameters
net2 = MLP()
net2.load_state_dict(torch.load(PATH))  # load parameters into a fresh model
Y2 = net2(X)
# BUG FIX: the original compared Y2 with itself (Y2 == Y2), which is
# trivially all-True; compare against the original model's output.
Y2 == Y
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
保存和加载整个模型
torch.save(net2,PATH) # save the entire model object (pickles the class)
model=torch.load(PATH)
print(model)
D:\Anaconda3\lib\site-packages\torch\serialization.py:360: UserWarning: Couldn't retrieve source code for container of type MLP. It won't be checked for correctness upon loading.
"type " + obj.__name__ + ". It won't be checked "
MLP(
(hidden): Linear(in_features=784, out_features=256, bias=True)
(act): ReLU()
(output): Linear(in_features=256, out_features=10, bias=True)
)