【PyTorch】模型复现的学习笔记

模型对齐速查表

复现对齐步骤:

  • 预处理对齐
  • 模型权值对齐
  • 前向对齐(target对齐)
  • 反向对齐(loss对齐)
  • 后处理对齐(box对齐)

数值对齐标准:mean在{float32: 1e-6 | 1e-5 | 1e-3 | 1e-2}之内
对齐代码参考:
在这里插入图片描述
模型对齐流程:
一次前向反向对齐(1次迭代);
多次反向传播对齐(2次或5次或10次迭代loss);
输出差异排查:二分查找法

I. 超参数转写

在这一步,需要保证两份代码的参数配置是对齐的;

II. 张量对齐

需要对齐:

  1. 形状:tensor.shape
  2. 均值:tensor.abs().mean()

III. 模型权值转换

权重转换:

  1. 得到MMDet模型的Pytorch权重(.pth
  2. 生成参数名对应列表
  3. 根据列表将MMDet权重转换为native-torch模型权值

Paddle教程参考代码:

def load_torch_model(file_name):
    weights = torch.load(file_name)
    state_dict = weights['state dict']
    for k in state_dict.keys():
        state_dict[k] = state_dict[k].numpy()
    return state_dict

参考资料

Paddle冠冠老师权重对齐代码:Debug_gfl/gfl_debug_tools/weight_convert.sh

III. 前向运算对齐

主要原则

  • 输入张量一致
  • 测试输出张量足够接近(allclose

Note
在前向对齐的时候,可以将模型设置为eval模式,先验证target张量是否已经对齐。

IV. 反向传播对齐

对齐速查表

  1. 输入数据
  2. 权重参数
  3. 学习策略及超参数

对齐目标

对齐目标量:中间输出值和梯度值。

对齐步骤

  1. 训练几次迭代,对比paddle和torch的训练loss是否一致。若不一致,打印中间参数与梯度,输出并对比差异,定位差异点,并分析问题所在。
  2. 用相同数据集,学习策略,超参数分别训练oaddle和torch模型,对比训练log与可视化效果是否一致。

可以使用模拟数据(“假数据”)进行训练,看一下loss的变化情况;

使用序列化文件对齐的方法

关于使用序列化文件进行反向传播对齐的步骤,请参考《Paddle反向梯度排查方法文档 | 反向对齐核验》

获得梯度信息:可以使用register_hook()

我们可以使用register_hook()钩子函数获得梯度信息;

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值