验证导出的mindir是否有问题

可能会有小伙伴310推理的时候出现问题,但是找不出问题的具体原因。

此给出一个刨除mindir错误的方法:.mindir验证。

Export

一般我们在执行export.py脚本文件后会导出 mindir后缀的模型权重文件。

以模型RBPN的代码(输入有三个)如下:

#定义模型model = RBPN(num_channels=3, base_filter=256, feat=64, num_stages=3, n_resblock=5, nFrames=7,scale_factor=4)
#载入权重文件到模型params = load_checkpoint(args.ckpt)
load_param_into_net(model, params)

model.set_train(False)
#定义模型的输入维度input_array = Tensor(np.zeros([1, 3, 120, 180], np.float32))
neighbor_array = Tensor(np.zeros([1, 6, 3, 120, 180], np.float32))
flow_array = Tensor(np.zeros([1, 6, 2, 120, 180], np.float32))
input_all = [input_array, neighbor_array, flow_array]
#定义存储文件名和路径G_file = "{}_model".format(args.file_name)
mindir_path = 'mindir_path'
file_path = osp.join(os.getcwd(), mindir_path)
if not osp.exists(file_path):
    os.makedirs(file_path)
G_file_path = os.path.join(file_path, G_file)
#导出权重文件
export(model, *input_all, file_name=G_file_path, file_format=args.file_format)

执行上述代码后生成 .mindir 权重文件

验证mindir

可以仿照以下代码进行推理验证

#获取数据集
val_dataset = RBPNDatasetTest(args.val_path, args.nFrames, args.upscale_factor, args.file_list,                                           args.other_dataset, args.future_frame)
val_ds = create_val_dataset(val_dataset, args)
#创建迭代器
train_loader = val_ds.create_dict_iterator()
#加载mindir模型
graph = load("./model.mindir")
net = nn.GraphCell(graph)
net.to_float(mindspore.float16)
#推理for i, data in enumerate(train_loader, 1):
    #获取数据 
    input = data['input_image']
    neigbor = data['neigbor_image']
    target = data['target_image']
    input_all = [input, neigbor, flow]
    #输出结果
    output = net(*input_all)
    psnr = psnrNet(target , output)

如果发现得到的精度值和直接使用该模型的ckpt权重文件(eval脚本)进行推理的结果一致,那么恭喜,你导出的mindir文件是没有问题的,不用再纠结export是否有误。

注:记得在推理的时候,将shuffle设置为false,别因为数据顺序不一样而测试出不一样的答案。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值