Tensorflow 中保存模型后加载后模型报维度不匹配错误

Tensorflow 中保存模型后加载后模型报维度不匹配错误

  you must feed  tensor  with [?,784]

  最后发现是占位符的名称错误

  重新使用加载模型的占位符后错误修改完成

logits, train_op, eval_correct, loss, data_placeholder, labels_placeholder = mnist_cnn_model(2500)
feed_dict_lo = {str(data_placeholder.name): train_data,str(labels_placeholder.name):train_label}
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
确实,对于剪枝后的模型与原始模型,它们的结构已经发生了变化,因此在加载剪枝后的模型时,需要考虑多个因素,如何正确地加载模型参数和 mask,以及如何将 mask 应用到剪枝后的模型上,等等。 不过,如果你使用 PyTorch 自带的剪枝工具,那么在剪枝模型时,PyTorch 会自动为模型保存一个 mask,用于记录哪些参数被剪枝掉了。在加载模型时,你只需要同时加载模型参数和 mask,并将 mask 应用到剪枝后的模型上即可。以下是一个示例代码,展示了如何加载 PyTorch 剪枝后的模型: ```python import torch import torch.nn as nn import torch.nn.utils.prune as prune # 定义一个简单的模型 class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.fc1 = nn.Linear(10, 5) self.fc2 = nn.Linear(5, 2) def forward(self, x): x = self.fc1(x) x = torch.relu(x) x = self.fc2(x) return x # 加载剪枝前模型的定义 model = MyModel() # 加载剪枝后模型的参数和 mask model.load_state_dict(torch.load('pruned_model.pt', map_location='cpu')) prune_state_dict = torch.load('pruned_model.pt', map_location='cpu') # 应用 mask 到剪枝后的模型上 for name, module in model.named_modules(): if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): prune.CustomFromMask.apply(module, 'weight', prune_state_dict[name + '.weight_mask']) # 将模型加载到 CUDA 设备上 device = torch.device('cuda:0') model.to(device) # 对输入数据进行计算 x = torch.randn(3, 10).to(device) output = model(x) # 检查输出数据的维度 print(output.size()) # 输出为 torch.Size([3, 2]) ``` 在这个示例代码,我们加载剪枝前模型的定义,并使用 `load_state_dict()` 方法加载剪枝后的模型参数和 mask。然后,我们遍历剪枝后的模型的每个模块,如果模块是一个卷积层或全连接层,就从 mask 加载相应的 mask,并将其应用到剪枝后的模型上。最后,我们将模型加载到 CUDA 设备上,并使用输入数据对模型进行计算。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值