一:PyTorch里面的torch.nn.Parameter()
我看的是这篇博客的解释,有需要其实可以去官网看看(参考博客)
二:mmdetection/mmdet/models/utils/scale.py
import torch
import torch.nn as nn
class Scale(nn.Module):
def __init__(self, scale=1.0):
super(Scale, self).__init__()
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
def forward(self, x):
return x * self.scale
这里定义了一个Scale类,这个scale怎么用具体看网络时再看
mmdetection/mmdet/models/utils/conv_ws.py
import torch.nn as nn
import torch.nn.functional as F
def conv_ws_2d(input,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1,
eps=1e-5):
c_in = weight.size(0)
weight_flat = weight.view(c_in, -1)
mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
weight = (weight - mean) / (std + eps)
return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
class ConvWS2d(nn.Conv2d):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
eps=1e-5):
super(ConvWS2d, self).__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias)
self.eps = eps
def forward(self, x):
return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
self.dilation, self.groups, self.eps)
这里定义了一个ConvWS2d类,看代码主要就是把weights进行了一个归一化(减均值除标准差)处理 。具体使用后面再看。
建议:mmdetection/mmdet/models/utils 这个文件夹下定义了build_conv_layer, build_norm_layer这些函数,作用就是在构建网络时,从cfg配置项build相应的操作,主要包括几种conv和norm的方式(实现都是PyTorch自带的)。建议这个文件夹下的几个文件都看看。