pytorch 换版本_Pytorch 模型版本切换

0.3.1转到0.4.1或更高版本

直接使用代码导入时常碰到 ‘BatchNorm2d’ object has no attribute ‘track_running_stats’的报错信息,这是由于0.3.1中的BN操作中没有配置track_running_stats参数,0.3.1中BatchNorm的定义如下

class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True)

$$ y = frac{x - mean[x]}{ sqrt{Var[x] + epsilon}} * gamma + beta$$Parameters:

num_features – num_features from an expected input of size batch_size x num_features x height x width

eps – a value added to the denominator for numerical stability. Default: 1e-5

momentum – the value used for the running_mean and running_var computation. Default: 0.1

affine – a boolean value that when set to True, gives the layer learnable affine parameters. Default: TrueShape:

Input: (N,C,H,W)

Output: (N,C,H,W) (same shape as input)

而在0.4.1中定义发生了变化

class torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

If track_running_stats is set to False, this layer then does not keep running estimates,

and batch statistics are instead used during evaluation time as well.

所以使用0.4.1或以上版本导入0.3.1模型时需要对模型中的BN层添加track_running_stats参数,代码如下1

2

3

4

5

6

7

8

9

10

11

12def (module):

if isinstance(module, torch.nn.BatchNorm2d):

module.track_running_stats = True

else:

for name, module1 in module._modules.items():

module1 = recursion_change_bn(module1)

check_point = torch.load(check_point_file_path)

model = check_point['net']

for name, module in model._modules.items():

recursion_change_bn(model)

model.eval()

另外,也可以在导入模型处直接修改模型,模型的statedict本身可以理解为一个Orderdict,在模型中添加参数num_batches_tracked对应的值即可. 具体做法是在键值为running_var后添加一个键值为num_batches_tracked,值为0的Tensor. 具体代码如下1

2

3

4

5

6

7

8

9checkpoint = torch.load(checkpoint_path, map_location=device)

mapped_state_dict = OrderedDict()

for key, value in checkpoint['state_dict'].items():

print(key)

mapped_key = key

mapped_state_dict[mapped_key] = value

if 'running_var' in key:

mapped_state_dict[key.replace('running_var', 'num_batches_tracked')] = torch.zeros(1).to(device)

model.load_state_dict(mapped_state_dict)

0.3.1版本导入0.4.1以上版本模型0.4中使用设备:.to(device)

0.4中删除了Variable,直接tensor就可以

with torch.no_grad():的使用代替volatile;弃用volatile,测试中不需要计算梯度的话,用with torch.no_grad():

data改用.detach;x.detach()返回一个requires_grad=False的共享数据的Tensor,并且,如果反向传播中需要x,那么x.detach返回的Tensor的变动会被autograd追踪。相反,x.data()返回的Tensor,其变动不会被autograd追踪,如果反向传播需要用到x的话,值就不对了。

pytorch0.4有一些接口已经改变,且模型向下版本兼容,不向上兼容。

使用pytorch0.3导入pytorch0.4保存的模型时候在导入前添加如下代码段,解决的报错内容为(AttributeError: Can’t get attribute ‘_rebuild_tensor_v2’ on \lib\site-packages\torch\_utils.py'>),详情可对比查看_utils.py文件:1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28# See https://discuss.pytorch.org/t/question-about-rebuild-tensor-v2/14560

import torch

# ***********pytorch0.3.1导入0.4.1以上版本模型时加入以下代码块**********

# 使用以下函数代替torch._utils中的函数(0.3.1中可能不存在或者接口不同导致的报错)

try:

torch._utils._rebuild_tensor_v2

except AttributeError:

def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):

tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)

tensor.requires_grad = requires_grad

tensor._backward_hooks = backward_hooks

return tensor

torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

try:

torch._utils._rebuild_parameter

except AttributeError:

def _rebuild_parameter(data, requires_grad, backward_hooks):

param = torch.nn.Parameter(data, requires_grad)

# NB: This line exists only for backwards compatibility; the

# general expectation is that backward_hooks is an empty

# OrderedDict. See Note [Don't serialize hooks]

param._backward_hooks = backward_hooks

return param

torch._utils._rebuild_parameter = _rebuild_parameter

# ***********************************************************************

在导出为ONNX模型时还可能会报错存在多余的num_batches_tracked值, 错误代码为KeyError: 'unexpected key "module.bn1.num_batches_tracked" in state_dict', 此处的处理方式和上边的添加num_batches_tracked键值对应,删除该键值即可,具体代码如下1

2

3

4

5

6

7

8

9checkpoint = torch.load(checkpoint_path, map_location=device)

mapped_state_dict = OrderedDict()

for key, value in checkpoint['state_dict'].items():

print(key)

mapped_key = key

mapped_state_dict[mapped_key] = value

if 'num_batches_tracked' in key:

del mapped_state_dict[key]

model.load_state_dict(mapped_state_dict)由0.4.1导出为0.3.1的ONNX模型时,上述两段代码都需要加入

导出为1.0.0模型

pytorch1.0.0添加了torch.jit, 可以直接将模型和网络打包到模型文件中,而不需要在使用模型文件时导入网络定义,在模型的使用时变得更加方便了

模型的jit导出1

2

3

4

5def pth_to_jit(model, save_path, device="cuda:0"):

model.eval()

input_x = torch.randn(1, 3, 144, 144).to(device) # 输入大小

new_model = torch.jit.trace(model, input_x)

torch.jit.save(new_model, save_path)

jit模型导入使用1

2

3def load_jit(jit_model_path):

model = torch.jit.load(jit_model_path, map_location=torch.device('cuda:0'))

model.eval()

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
如果你要在 PyTorch 切换数据集加载模型,你需要修改数据加载器的代码以适应新的数据集。具体来说,你需要更新数据集的路径、图像大小、批量大小等参数。同时,你还需要确保数据集的格式与你的模型训练代码的预期格式相同。 以下是一个简单的代码示例,以 MNIST 数据集为例: ```python import torch import torchvision import torchvision.transforms as transforms # 定义数据集路径 train_dataset_path = '/path/to/new/train/dataset' test_dataset_path = '/path/to/new/test/dataset' # 定义图像转 transform = transforms.Compose( [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) # 加载训练集 trainset = torchvision.datasets.MNIST(root=train_dataset_path, train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2) # 加载测试集 testset = torchvision.datasets.MNIST(root=test_dataset_path, train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=2) ``` 在这个示例,我们首先定义了新数据集的路径,然后使用 PyTorch 的内置 MNIST 数据集函数来加载数据。我们还定义了图像转,以确保每个图像都具有相同的大小和格式。最后,我们使用 PyTorch 的 DataLoader 类来创建训练集和测试集的加载器,以便我们可以在模型训练代码使用它们。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值