玩转pytorch和tensorflow之(4)——GRU

        作为RNN神经网络的主力之一,GRU频频出现在各种网络结构中,因为其效率高于LSTM而得到青睐。如果遇到需要把GRU模块从tensorflow转到pytorch或者相反,那么,这篇文章将会帮助你迅速达成目标。

1 tensorflow GRU转pytorch

        OK,你已经有了一个GRU模块,tensorflow下训练的,现在你想用训练好的GRU模型的参数在pytorch环境中来做推理。

        tensorflow下GRU参数有三组:

                kernel,recurrent_kernel,bias

        如果你用get_weights接口去获取的话,你会得到包含上述3个元素的列表。

        而pytorch的GRU则有weight_ih_l0,weight_hh_l0,bias_ih_l0,bias_hh_l0这4组参数。你需要把kernel,recurrent_kernel,bias转换到weight_ih_l0,weight_hh_l0,bias_ih_l0,bias_hh_l0上去。

        1)kernel(tensorflow)转weight_ih_l0(pytorch):

        r,z,h=np.vsplit(weights[0].T, 3)                #weight[0]存放的是kernel权重参数
        weight_ih_l0=np.concatenate([z,r,h],axis=0)

        2)recurrent_kernel(tensorflow)转weight_hh_l0(pytorch):

        r,z,h=np.vsplit(weights[1].T, 3)                #weight[1]存放的是recurrent_kernel权重参数
        weight_ih_l0=np.concatenate([z,r,h],axis=0)

        3)bias(tensorflow)转bias_ih_l0,bias_hh_l0(pytorch):

      

  • 12
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

geastwind1

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值