[LibTorch] 指定参数不进行学习

出于不同的目的,希望固定预训练网络的某一些卷积层不进行参数的学习。

整个网络不进行参数学习

ResNet net = resnet50();
net->to(device);
torch::load(net, "../models/resnet50_caffe.pt");
for (const auto& p : net->parameters()) {
    p.requires_grad_(false);
}

某一个 BN 层不进行参数学习

/* 放在网络的构造函数中使用 */
register_module("conv1", conv1);
register_module("bn1", bn1);
for(const auto& p : bn1->parameters()) {
    p.requires_grad_(false);
}

网络中所有的 Conv 和 BN 层不进行参数学习

/* 放在网络的构造函数中使用 */
for(const auto& m : this->modules(/*include_self=*/false)) {
    if(auto* conv = m->as<torch::nn::Conv2d>()) {
    	/* 卷积层只有权重 */
    	conv->weight.requires_grad_(false);
    	/* 也可以这样写 */
		for(const auto& p : m->parameters()) {
    		p.requires_grad_(false);
		}
    }
    else if(auto* bn = m->as<torch::nn::BatchNorm2d>()) {
    	/* BN 层有权重和偏置 */
        bn->weight.requires_grad_(false);
        bn->bias.requires_grad_(false);
    	/* 也可以这样写 */
		for(const auto& p : m->parameters()) {
    		p.requires_grad_(false);
		}
    }
}

这种方法也适用于参数初始化

for(const auto& m : this->modules(/*include_self=*/false)) {
    if(auto* conv = m->as<torch::nn::Conv2d>()) {
        conv->weight.normal_(0, 0.01);
    }
    else if(auto* bn = m->as<torch::nn::BatchNorm2d>()) {
        bn->weight.fill_(1);
        bn->bias.zero_();
    }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值