torch::Device device = torch::Device(torch::kCPU)
int gpu_num = torch::getNumGPUs();
if (gpu_id >= 0) device = torch::Device(torch::kCUDA, gpu_id);
二维卷积
// Build the option list for a 2-D convolution layer.
//
// in_planes:   number of input channels.
// out_planes:  number of output channels.
// kernel_size: square kernel edge length (typo "kerner_size" fixed;
//              C++ calls are positional, so callers are unaffected).
// stride / padding / groups / dilation: standard conv hyper-parameters.
// with_bias:   whether the conv learns an additive bias
//              (callers below pass false when a BatchNorm layer follows).
inline torch::nn::Conv2dOptions conv_options(int64_t in_planes, int64_t out_planes, int64_t kernel_size,
    int64_t stride = 1, int64_t padding = 0, int groups = 1, bool with_bias = true, int dilation = 1) {
    // Renamed the local so it no longer shadows the function name.
    torch::nn::Conv2dOptions options = torch::nn::Conv2dOptions(in_planes, out_planes, kernel_size);
    options.stride(stride);
    options.padding(padding);
    options.bias(with_bias);
    options.groups(groups);
    options.dilation(dilation);
    return options;
}
// 3 input channels, 64 output channels, 7x7 kernel, stride 2, padding 3, groups 1, no bias.
conv1 = torch::nn::Conv2d(conv_options(3, 64, 7, 2, 3, 1, false));
上采样
// Nearest-neighbor upsampling that doubles both spatial dimensions (scale factor 2x2).
upsample = torch::nn::Upsample(torch::nn::UpsampleOptions().scale_factor(std::vector<double>({2,2})).mode(torch::kNearest));
batchnorm
// Batch normalization over a 64-channel feature map.
bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64));
resnet的layer
// Stem: 3 -> 64 channels, 7x7 kernel, stride 2, padding 3, no bias.
conv1 = torch::nn::Conv2d(conv_options(3, 64, 7, 2, 3, 1, false));
bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(64));
layer1 = torch::nn::Sequential(_make_layer(64, layers[0])); // stride 1
layer2 = torch::nn::Sequential(_make_layer(128, layers[1], 2)); // stride 2
layer3 = torch::nn::Sequential(_make_layer(256, layers[2], 2)); // layers[] holds the block count of each stage (see _make_layer's "blocks" parameter)
layer4 = torch::nn::Sequential(_make_layer(512, layers[3], 2));
// Classifier head; expansion is 1 for BasicBlock-style nets and 4 for BottleNeck-style — TODO confirm against the member's initialization.
fc = torch::nn::Linear(512 * expansion, num_classes);
// Register every submodule so its parameters are tracked and serialized by the parent module.
register_module("conv1", conv1);
register_module("bn1", bn1);
register_module("layer1", layer1);
register_module("layer2", layer2);
register_module("layer3", layer3);
register_module("layer4", layer4);
register_module("fc", fc);
// Assemble one ResNet stage.
//
// planes: base output channel count of every block in this stage.
// blocks: how many blocks the stage stacks.
// stride: stride applied by the first block (spatial downsampling).
torch::nn::Sequential ResNetImpl::_make_layer(int64_t planes, int64_t blocks, int64_t stride) {
    // A projection shortcut (1x1 conv + BN) is required whenever the first
    // block changes the spatial resolution or the channel count.
    torch::nn::Sequential shortcut;
    const bool needs_projection = (stride != 1) || (inplanes != planes * expansion);
    if (needs_projection) {
        shortcut = torch::nn::Sequential(
            torch::nn::Conv2d(conv_options(inplanes, planes * expansion, 1, stride, 0, 1, false)),
            torch::nn::BatchNorm2d(planes * expansion)
        );
    }

    torch::nn::Sequential stage;
    // First block: may downsample and carries the projection shortcut.
    stage->push_back(Block(inplanes, planes, stride, shortcut, groups, base_width, is_basic));
    inplanes = planes * expansion; // subsequent blocks see the widened channel count
    // Remaining blocks: stride 1, identity shortcut, unchanged channel count.
    for (int64_t b = 1; b < blocks; ++b) {
        stage->push_back(Block(inplanes, planes, 1, torch::nn::Sequential(), groups, base_width, is_basic));
    }
    return stage;
}
resnet的block
BasicBlock 由两个 conv+bn 组成,卷积核大小都是3。
// is_basic == true : BasicBlock,  3x3 conv + 3x3 conv.
// is_basic == false: BottleNeck,  1x1 conv + 3x3 conv + 1x1 conv (expands to planes * 4).
//
// inplanes:    input channel count.
// planes:      base output channel count.
// stride_:     stride of the spatially-reducing conv.
// downsample_: optional projection shortcut (empty Sequential when unused).
// groups / base_width: ResNeXt-style width controls.
BlockImpl::BlockImpl(int64_t inplanes, int64_t planes, int64_t stride_,
    torch::nn::Sequential downsample_, int groups, int base_width, bool _is_basic)
{
    downsample = downsample_;
    stride = stride_;
    is_basic = _is_basic;
    int width = int(planes * (base_width / 64.)) * groups;
    // Build each conv exactly once. The original constructed the BasicBlock
    // convs unconditionally and then overwrote conv1/conv2 in the BottleNeck
    // case, wasting two module allocations.
    if (is_basic) {
        // BasicBlock: 3x3 (stride) + 3x3.
        conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 3, stride_, 1, groups, false));
        conv2 = torch::nn::Conv2d(conv_options(width, width, 3, 1, 1, groups, false));
    } else {
        // BottleNeck: 1x1 reduce + 3x3 (stride) + 1x1 expand to planes * 4.
        conv1 = torch::nn::Conv2d(conv_options(inplanes, width, 1, 1, 0, 1, false));
        conv2 = torch::nn::Conv2d(conv_options(width, width, 3, stride_, 1, groups, false));
        conv3 = torch::nn::Conv2d(conv_options(width, planes * 4, 1, 1, 0, 1, false));
        bn3 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(planes * 4));
    }
    bn1 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    bn2 = torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(width));
    // Register submodules so their parameters are tracked by the parent module.
    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("conv2", conv2);
    register_module("bn2", bn2);
    if (!is_basic) { // BottleNeck additionally owns conv3/bn3.
        register_module("conv3", conv3);
        register_module("bn3", bn3);
    }
    if (!downsample->is_empty()) {
        register_module("downsample", downsample);
    }
}
删除resnet最后的全连接层
// Copy every pretrained parameter except the final fully-connected ("fc.")
// layer, so the backbone weights can be reused under a new classification head.
torch::OrderedDict<std::string, at::Tensor> pretrained_dict = net_pretrained->named_parameters(); // key/value map of the pretrained resnet's parameters
torch::OrderedDict<std::string, at::Tensor> model_dict = this->named_parameters();
for (const auto& item : pretrained_dict) // range-for replaces the manual post-increment iterator loop
{
    // std::string::find replaces the C-style strstr on .data().
    if (item.key().find("fc.") != std::string::npos) { // skip classifier-head parameters
        continue;
    }
    model_dict[item.key()] = item.value();
}