Weights_Keras_2_Pytorch
最近想在Pytorch项目里使用一下谷歌的NIMA,但是发现没有预训练好的pytorch权重,于是整理了一下将Keras预训练权重转为Pytorch的代码,目前是支持Keras的Conv2D, Dense, DepthwiseConv2D, BatchNormlization的转换。需要注意的是在Pytorch模型内需要给每一层命名为与Keras每一层相同的名字,才能对应转换。
代码地址:
https://github.com/AgCl-LHY/Weights_Keras_2_Pytorch
核心代码:
def keras_to_pyt(km, pm):
weight_dict = dict()
for layer in km.layers:
if type(layer) is keras.layers.convolutional.Conv2D:
if (len(layer.get_weights()) >= 1):
weight_dict[layer.get_config()['name'] + '.weight'] = np.transpose(layer.get_weights()[0], (3, 2, 0, 1))
if (len(layer.get_weights())