作为RNN神经网络的主力之一,GRU频频出现在各种网络结构中,因为其效率高于LSTM而得到青睐。如果遇到需要把GRU模块从tensorflow转到pytorch或者相反,那么,这篇文章将会帮助你迅速达成目标。
1 tensorflow GRU转pytorch
OK,你已经有了一个GRU模块,tensorflow下训练的,现在你想用训练好的GRU模型的参数在pytorch环境中来做推理。
tensorflow下GRU参数有三组:
kernel,recurrent_kernel,bias
如果你用get_weights接口去获取的话,你会得到包含上述3个元素的列表。
而pytorch的GRU则有weight_ih_l0,weight_hh_l0,bias_ih_l0,bias_hh_l0这4组参数。你需要把kernel,recurrent_kernel,bias转换到weight_ih_l0,weight_hh_l0,bias_ih_l0,bias_hh_l0上去。
1)kernel(tensorflow)转weight_ih_l0(pytorch):
r,z,h=np.vsplit(weights[0].T, 3) #weight[0]存放的是kernel权重参数
weight_ih_l0=np.concatenate([z,r,h],axis=0)
2)recurrent_kernel(tensorflow)转weight_hh_l0(pytorch):
r,z,h=np.vsplit(weights[1].T, 3) #weight[1]存放的是recurrent_kernel权重参数
weight_ih_l0=np.concatenate([z,r,h],axis=0)
3)bias(tensorflow)转bias_ih_l0,bias_hh_l0(pytorch):