# 1、自定义网络和模型保存 (Custom network definition and model saving/loading)
import torch
import torch.nn as nn
class Testnet(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Args:
        in_channel: number of input channels for the convolution.
        out_channel: number of output channels of the convolution, which is
            also the number of features normalized by BatchNorm. Defaults to 32.
        kernel_size: convolution kernel size. Defaults to 3. No padding is
            applied, so each spatial dimension shrinks by ``kernel_size - 1``.

    NOTE: a conv bias directly before BatchNorm is redundant (the mean
    subtraction cancels it), but it is kept so the state_dict keys
    (``conv.bias`` etc.) stay compatible with existing checkpoints.
    """

    def __init__(self, in_channel, out_channel=32, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size)
        self.bn = nn.BatchNorm2d(out_channel)
        # inplace=True overwrites the BatchNorm output to save memory; this is
        # safe because that intermediate tensor is not used anywhere else.
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv, batch norm and ReLU to ``x`` (an N x C x H x W batch)."""
        out = self.conv(x)
        out = self.bn(out)
        return self.act(out)
model = Testnet(3, 32)
# Inspect the parameters and buffers registered by the submodules.
print(model.state_dict().keys())
# Output:
# odict_keys(['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias',
#             'bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked'])
# Example checkpoint locations; change them to wherever the files should live.
PATH1 = "testnet_full.pth"
PATH2 = "testnet_state_dict.pth"

# a) Save the whole network. This pickles the module object, so loading it
#    later requires the Testnet class to be importable from the same module path.
torch.save(model, PATH1)

# b) Save only the network's parameters and buffers (the state_dict). This is
#    the approach PyTorch recommends: the checkpoint holds just tensors and is
#    not tied to how the model class was pickled.
torch.save(model.state_dict(), PATH2)

# Load the model back.
# a) The whole pickled module. Since PyTorch 2.6, torch.load defaults to
#    weights_only=True, which rejects arbitrary pickled objects such as a full
#    nn.Module, so it must be disabled explicitly here.
#    WARNING: unpickling can execute arbitrary code -- only load trusted files.
model = torch.load(PATH1, weights_only=False)

# b) Rebuild the architecture first, then copy the saved tensors into it.
#    A plain state_dict loads under the safe weights_only=True mode.
model = Testnet(3, 32)
model.load_state_dict(torch.load(PATH2, weights_only=True))
# Call model.eval() before running inference so BatchNorm uses its running
# statistics instead of per-batch statistics.