改网络,改结构,代码改来改去,发现一些基础功能都忘了。不打算探索“茴香豆的茴字有几种写法”,只研究见得最多的。其他的写法等到需要的时候,自然会研究,现在看完了也记不住。
torch.nn.Module搭建网络及相关探索
一、创建网络
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.f = torch.nn.Sequential(torch.nn.Linear(3, 10),
torch.nn.ReLU(),
torch.nn.BatchNorm1d(10),
torch.nn.Linear(10, 3))
def forward(self, input):
output = self.f(input)
return output
net = Net()
print(net)
输出
Net(
(f): Sequential(
(0): Linear(in_features=3, out_features=10, bias=True)
(1): ReLU()
(2): BatchNorm1d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): Linear(in_features=10, out_features=3, bias=True)
)
)
这么一看,直接用print就可以看清楚网络结构了
二、对网络参数的查看和修改
1.但是现在,出于某种原因,我就是想看看具体的参数
for name, para in net.named_parameters():
print(name, para)
输出
f.0.weight Parameter containing:
tensor([[ 0.5060, -0.4156, -0.4121],
[-0.0264, 0.1661, -0.4104],
[-0.1067, 0.4469, 0.3226],
[-0.1222, -0.5767, -0.4071],
[ 0.4389, -0.4107, 0.4619],
[-0.2293, -0.0214, 0.5376],
[ 0.3034, -0.0662, -0.2026],
[-0.2695, -0.1064, -0.0521],
[ 0.2155, 0.4808, -0.3267],
[-0.5386, -0.4093, 0.0512]], requires_grad=True)
f.0.bias Parameter containing:
tensor([ 0.5499, 0.1722, 0.2635, -0.4126, -0.3905, -0.3570, 0.2032, 0.2197,
-0.3590, 0.4718], requires_grad=True)
f.2.weight Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True)
f.2.bias Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)
f.3.weight Parameter containing:
tensor([[ 0.2335, -0.0968, 0.0603, -0.2349, -0.1113, -0.0302, -0.2450, 0.2158,
-0.1215, 0.2469],
[-0.3107, -0.2186, 0.2061, -0.0518, 0.0633, 0.0064, 0.2820, 0.3025,
0.1572, 0.0027],
[ 0.2837, -0.2840, -0.0259, -0.0866, 0.0288, 0.2672, 0.2236, -0.2937,
-0.2335, 0.0893]], requires_grad=True)
f.3.bias Parameter containing:
tensor([-0.2099, -0.1148, 0.1129], requires_grad=True)
这种方式可以很清晰的看到每一个参数的值。值得注意的是,如f.0.weight中间的0代表在nn.Sequential中的第几个位置。如果想以自己的命名方式来代替0,在创建网络时,可以使用add_module()方式。详见参考
另外,每一个weight和每一个bias都是一个tensor(这用变量形容更合适),他们的requires_grad都默认为True
2.我现在想对其中的一个weight的值进行更改,比如对第一个全连接层的权重进行更改:
net.f[0].weight.data = torch.rand([10,3])
三、net.train()和net.eval()
打开调试模式,假设我们查看第一个全连接层,会发现有一项属性training。每个网络层都有它。net.train()是将net下的所有网络层的training,都设置成False;而net.eval()是设置成training。
为什么要设置?主要是针对BatchNorm,DropOut等网络层。当False时,参数就固定下来,也就是在测试的时候,而True是用来训练的。尽管其他网络层也有training,但是似乎没什么用。
详见官方文档
四、net.cuda()和net.cpu()
就是说将网络放在GPU上还是放在CPU上