有时候任务需要,想从一个训练好的网络里提取部分网络和参数做为自己的网络,本文将教你如何用pytorch实现。
首先看一下训练好的网络结构:
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/ba7874183a7b2fa4545b7af11bf1b59e.png)
这是一个seq2seq网络,包含encoder和decoder两部分,每一部分都包含一个embedding层、一个LSTM层和一个Dropout层,decoder网络还有一个Linear层。
然后看一下新的网络结构:
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/865f246be154992deb3a2669085b5020.png)
同样是一个seq2seq的结构,区别是decoder网络里面只保留了一个Linear层,其他层都删掉了。那么怎么将训练好的网络参数复制过来呢?
1、保存预训练好的网络参数
torch.save(model.state_dict()