出于不同的目的,希望固定预训练网络的某一些卷积层不进行参数的学习。
整个网络不进行参数学习
ResNet net = resnet50();
net->to(device);
torch::load(net, "../models/resnet50_caffe.pt");
for (const auto& p : net->parameters()) {
p.requires_grad_(false);
}
某一个 BN 层不进行参数学习
/* 放在网络的构造函数中使用 */
register_module("conv1", conv1);
register_module("bn1", bn1);
for(const auto& p : bn1->parameters()) {
p.requires_grad_(false);
}
网络中所有的 Conv 和 BN 层不进行参数学习
/* 放在网络的构造函数中使用 */
for(const auto& m : this->modules(/*include_self=*/false)) {
if(auto* conv = m->as<torch::nn::Conv2d>()) {
/* 卷积层只有权重 */
conv->weight.requires_grad_(false);
/* 也可以这样写 */
for(const auto& p : m->parameters()) {
p.requires_grad_(false);
}
}
else if(auto* bn = m->as<torch::nn::BatchNorm2d>()) {
/* BN 层有权重和偏置 */
bn->weight.requires_grad_(false);
bn->bias.requires_grad_(false);
/* 也可以这样写 */
for(const auto& p : m->parameters()) {
p.requires_grad_(false);
}
}
}
这种方法也适用于参数初始化
for(const auto& m : this->modules(/*include_self=*/false)) {
if(auto* conv = m->as<torch::nn::Conv2d>()) {
conv->weight.normal_(0, 0.01);
}
else if(auto* bn = m->as<torch::nn::BatchNorm2d>()) {
bn->weight.fill_(1);
bn->bias.zero_();
}
}