迁移学习-模型参数加载

1. 模型定义时name规则

  • 定义了变量名的,name=变量名;
  • 没有定义变量名的,使用Sequential()的,从0开始标号name。
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1=nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3)  # conv1.*
        self.layer1=nn.Sequential(
            nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3),      # layer1.0.*
            nn.BatchNorm2d(64),                                            # layer1.1.*
        )
        self.layer2=nn.Sequential(
            nn.Conv2d(in_channels=128,out_channels=128,kernel_size=3),     # layer2.0.*

            nn.Conv2d(in_channels=128,out_channels=256,kernel_size=3),      # layer2.1.*

        )
        self.lay=self._make_layer()                                         #lay.0.* + lay.1.*
    def forward(self, input):
        out=self.conv1(input)
        out=self.lay(out)
        return out
    def _make_layer(self):

        self.layers = []
        self.layers.append(
            nn.Conv2d(12,12,4)
        )
        self.layers.append(nn.Conv2d(12, 12, 4))
        return nn.Sequential(*self.layers)

pars=Net().state_dict()

2. 训练参数迁移

# 你的模型
net=model()
# 训练好的模型参数读取
pre_dict=torch.load('c3d.pickle')  
# 你的模型参数, 即初始化参数
model_dict=net.state_dict() 
# 将pretrained_dict里不属于model_dict的键剔除掉 
pre_dict =  {name: value for name, value in pre_dict.items() if name in model_dict} 
# 更新现有的model_dict 
model_dict.update(pre_dict) 
# 加载我们真正需要的state_dict 
net.load_state_dict(model_dict) 

 

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值