pytorch模型参数C语言读取,pytorch | 模型相关(保存、载入、修改等)

模型以及参数是否能够合理地保存下来,是能否进行产品迭代的保证。

参考资料

保存模型参数

PyTorch 中保存模型的方式有许多种:

1

2

3

4

5

6

7

8# 保存整个网络,载入后可以直接用

torch.save(model, PATH)

# 保存网络中的参数, 速度快,占空间少

torch.save(model.state_dict(),PATH)

# 选择保存网络中的一部分参数或者额外保存其余的参数

torch.save({'state_dict': model.state_dict(), 'fc_dict':model.fc.state_dict(),

'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma},

PATH)

读取模型参数

同样的,PyTorch 中读取模型参数的方式也有许多种:

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19# 读取整个网络

model = torch.load(PATH)

# 读取 Checkpoint 中的网络参数

model.load_state_dict(torch.load(PATH))

# 值得注意的是,这样加载的是保存的参数信息

# 也就是 PATH 这个文件对应的是 model.state_dict()

# 值得原理后,处理起来就很灵活,比如

# 即便保存的是全部网络信息,我们只需要调用 model.state_dict() 就好了

# 希望你能看懂我所讲的意思

# 若 Checkpoint 中的网络参数与当前网络参数有部分不同,有以下两种方式进行加载:

# 1. 利用字典的 update 方法进行加载

Checkpoint = torch.load(Path)

model_dict = model.state_dict()

model_dict.update(Checkpoint)

model.load_state_dict(model_dict)

# 2. 利用 load_state_dict() 的 strict 参数进行部分加载

model.load_state_dict(torch.load(PATH), strict=False)

如何利用前面的参数呢?我以 官方 的一段代码进行阐述,当然,利用参数有两种情况

自己创建的 model

别人创建的 model

自己的 model

以下代码参照

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37class Net(nn.Module):

def __init__(self):

super(Net, self).__init__()

self.conv1 = nn.Conv2d(3, 6, 5)

self.pool = nn.MaxPool2d(2, 2)

self.conv2 = nn.Conv2d(6, 16, 5)

self.fc1 = nn.Linear(16 * 5 * 5, 120)

self.fc2 = nn.Linear(120, 84)

self.fc3 = nn.Linear(84, 10)

def forward(self, x):

x = self.pool(F.relu(self.conv1(x)))

x = self.pool(F.relu(self.conv2(x)))

x = x.view(-1, 16 * 5 * 5)

x = F.relu(self.fc1(x))

x = F.relu(self.fc2(x))

x = self.fc3(x)

return x

# 经过一系列的计算,保存参数

PATH = './cifar_net.pth'

torch.save(net.state_dict(), PATH)

# 使用 参数

net = Net()

net.load_state_dict(torch.load(PATH))

outputs = net(images)

#dataiter = iter(testloader)

#images, labels = dataiter.next()

# 输出结果

_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]

for j in range(4)))

在这个代码中,我们可以看出,我们得预先知道整个网络的结构才行。

不使用 model,直接使用模型

关于这个代码,请大家看下面博文的第2个例子:一个简单的结构

别人的 model

微调网络结构

Pytorch 提供了许多 Pre-Trained Model on ImageNet,仅需调用 torchvision.models 即可,具体细节可查看官方文档。

往往我们需要对 Pre-Trained Model 进行相应的修改,以适应我们的任务。这种情况下,我们可以先输出 Pre-Trained Model 的结构,确定好对哪些层修改,或者添加哪些层,接着,再将其修改即可。

因为有预感在这个章节中可能有很多内容需要撰写更新,所以,决定开设新的博文,请移步到下面的博文。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值