导语
- 通过把经典分类器ResNet改成检测器来加深对检测器的理解
- 只是为了理解,所以这个简易检测器没有框回归,只是检测目标中心点
把ResNet改成检测器
- 这里使用ResNet18, 32倍下采样
- 改成了4倍下采样的检测器
- 上采样的常见做法有两种: 插值(如最近邻)+卷积, 以及转置卷积(也叫反卷积); 这里使用最近邻插值 + 3x3卷积
- ssh设计,最后的设计连接头
第一步分析每个Featuremap尺寸
- 不再考虑分类任务中的小尺寸输入: 从检测器开始, 输入一般都是大图, 不存在小输入
- 去掉末尾的avgpool和fc, 因为检测器需要的是FCN(全卷积网络)
- 这里写一个钩子(forward hook)来打印每层输出的尺寸
import torch
import torchvision.models as models

model = models.resnet18()


class hook:
    """Forward hook that prints the hooked module's output shape and its name."""

    def __init__(self, name):
        self.name = name

    def __call__(self, module, inputs, output):
        print(output.shape, self.name)


# register_forward_hook fires after a module's forward();
# register_forward_pre_hook would fire before it.
handles = [item.register_forward_hook(hook(name)) for name, item in model.named_children()]
with torch.no_grad():
    _ = model(torch.zeros(1, 3, 128, 128))
# Detach the hooks so later uses of `model` do not keep printing shapes.
for handle in handles:
    handle.remove()
output:
torch.Size([1, 64, 64, 64]) conv1
torch.Size([1, 64, 64, 64]) bn1
torch.Size([1, 64, 64, 64]) relu
torch.Size([1, 64, 32, 32]) maxpool
torch.Size([1, 64, 32, 32]) layer1
torch.Size([1, 128, 16, 16]) layer2
torch.Size([1, 256, 8, 8]) layer3
torch.Size([1, 512, 4, 4]) layer4
torch.Size([1, 512, 1, 1]) avgpool
torch.Size([1, 1000]) fc
- 由输出可知, 从layer1开始, 每一层的空间尺寸都是下一层的两倍, 通道数是下一层的一半(例如layer1为64x32x32, layer2为128x16x16), 以此类推, 所以设计上采样时取scale = 2
上采样和投影
- 上采样是用来把下层的feature map扩到跟上层一样的尺寸 height x width
- 投影是用一个1x1的卷积 + BatchNorm 来改变通道数; 在深度学习中1x1卷积通常起功能性作用, 例如改变通道数, 或者配合stride改变尺寸(这里stride = 1, 尺寸不变)
import torchvision.models as models
import torch
import torch.nn as nn
import torch.nn.functional as F
def upmodule(in_feature, out_feature, scale = 2):
    """Upsample a feature map by ``scale`` and map it to ``out_feature`` channels.

    Nearest-neighbour upsampling is followed by a bias-free 3x3 conv (padding 1,
    so the upsampled size is kept) and BatchNorm.
    """
    resize = nn.Upsample(scale_factor=scale, mode='nearest')
    smooth = nn.Conv2d(in_feature, out_feature, 3, stride=1, padding=1, bias=False)
    norm = nn.BatchNorm2d(out_feature)
    return nn.Sequential(resize, smooth, norm)
def projection_module(in_feature, out_feature):
    """Project ``in_feature`` channels to ``out_feature`` with a 1x1 conv + BatchNorm.

    Kernel 1, stride 1 and no padding leave the spatial size unchanged.
    """
    layers = [
        nn.Conv2d(in_feature, out_feature, 1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(out_feature),
    ]
    return nn.Sequential(*layers)
定义检测器
- 这里把Layer1之前的conv、BN、ReLU以及maxpool合并成了layer0(输出为4倍下采样); 去掉的是末尾的avgpool和fc, 因为检测器要的是FCN
- 上采样/投影分支的通道数统一为24(wide = 24)
class Detection(nn.Module):
    """Centre-point heatmap detector built on a ResNet-18 backbone.

    Backbone stages at strides 4/8/16/32 are fused top-down. At each level:
    - the deeper map is upsampled 2x with ``upmodule``;
    - the shallower map is projected to ``wide`` channels with ``projection_module``;
    - the two are summed and passed through ReLU.

    A 1x1 head turns the final stride-4 map into a single logit channel.

    Args:
        in_channels: number of input image channels (3 for RGB).
        wide: channel width of the top-down fusion branch.
    """

    def __init__(self, in_channels=3, wide=24):
        super().__init__()
        # Randomly initialised backbone (no pretrained weights are requested).
        model = models.resnet18()
        # Stem: conv/BN/ReLU/max-pool -> stride 4. The max-pool is needed for
        # layer1 to sit at stride 4; only the classifier's avgpool and fc are dropped.
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, padding=3, stride=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.layer1 = model.layer1  # stride = 4, channel = 64
        self.layer2 = model.layer2  # stride = 8, channel = 128
        self.layer3 = model.layer3  # stride = 16, channel = 256
        self.layer4 = model.layer4  # stride = 32, channel = 512

        self.u4 = upmodule(512, wide)
        self.p3 = projection_module(256, wide)
        self.u3 = upmodule(wide, wide)
        self.p2 = projection_module(128, wide)
        self.u2 = upmodule(wide, wide)
        self.p1 = projection_module(64, wide)
        self.head = nn.Conv2d(wide, 1, kernel_size=1, padding=0, stride=1)

    def forward(self, x):
        """Return a (N, 1, H/4, W/4) map of centre-point logits (no sigmoid applied).

        NOTE(review): assumes H and W are multiples of 32 so each 2x-upsampled map
        matches the lateral projection it is added to — confirm for other sizes.
        """
        x0 = self.layer0(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        u4 = self.u4(x4)
        p3 = self.p3(x3)
        o4 = F.relu(u4 + p3)  # stride 16, `wide` channels

        u3 = self.u3(o4)
        p2 = self.p2(x2)
        o3 = F.relu(u3 + p2)  # stride 8, `wide` channels

        u2 = self.u2(o3)
        p1 = self.p1(x1)
        o2 = F.relu(u2 + p1)  # stride 4, `wide` channels
        return self.head(o2)
可以通过导出ONNX用Netron查看结构
# Export the detector itself; the global `model` at this point is the plain
# ResNet-18 classifier created for the shape-printing demo.
torch.onnx.export(Detection(), (torch.zeros(1, 3, 128, 128),), "myresnetdetection.onnx")
训练前的准备: 绘制GT
- 这个检测器是检测中心点的, 用高斯核绘制GT
def draw_gauss(image, x, y, gsize):
    """Draw a 2-D Gaussian peak (value 1 at the centre) onto a heatmap, in place.

    Args:
        image: heatmap array indexed as image[row, col]; modified in place.
        x: integer centre column, in heatmap coordinates.
        y: integer centre row, in heatmap coordinates.
        gsize: kernel diameter in pixels; rounded up to the next odd number.

    Returns:
        The same ``image`` object.

    The kernel is clipped at the image border, so centres on or near the edge
    work; a centre whose kernel lies entirely outside the image is a no-op.
    Overlapping Gaussians are merged with an element-wise maximum, so drawing
    one object never lowers another object's peak.
    """
    import numpy as np

    gsize += 1 - (gsize % 2)  # force an odd size so the peak lands on a pixel
    sigma = gsize / 6  # the kernel spans roughly +-3 sigma
    s = 2 * sigma * sigma
    half = gsize // 2
    ky, kx = np.ogrid[-half:half + 1, -half:half + 1]
    kernel = np.exp(-(kx * kx + ky * ky) / s)

    height, width = image.shape[:2]
    left, top = x - half, y - half
    # Intersect the kernel window with the image bounds.
    x0, x1 = max(0, left), min(width, left + gsize)
    y0, y1 = max(0, top), min(height, top + gsize)
    if x0 >= x1 or y0 >= y1:
        return image
    patch = kernel[y0 - top:y1 - top, x0 - left:x1 - left]
    image[y0:y1, x0:x1] = np.maximum(image[y0:y1, x0:x1], patch)
    return image
# Object centre in input-image pixels (presumably read off with WeChat's Alt+A
# screenshot tool). The coordinates and the 30 px Gaussian diameter are divided
# by `stride` to map them onto the heatmap grid.
# NOTE(review): gt_heatmap and stride are not defined in this snippet; presumably
# stride == 4 (the detector's output stride) and gt_heatmap is a zero-initialised
# 2-D numpy array of shape (H // stride, W // stride) — confirm.
x, y = 78, 106 # object centre (x = column, y = row)
draw_gauss(gt_heatmap, x // stride, y // stride, 30 // stride)
训练前的准备: 制定Loss
- 使用CenterNet中改进的Focal Loss(alpha = 2, beta = 4)
- GT中值为1的位置(中心点)是正样本, 其余位置都是负样本; 负样本的loss按(1 - GT)^beta降权, 越靠近中心点的位置惩罚越小
def focal_loss(predict, target):
    """Penalty-reduced focal loss for a centre-point heatmap.

    Args:
        predict: predicted heatmap probabilities (i.e. after a sigmoid), same
            shape as ``target``.
        target: ground-truth heatmap; 1 at object centres, Gaussian-decayed
            values below 1 around them, 0 elsewhere.

    Returns:
        Scalar tensor: the summed loss divided by the number of centre points
        (at least 1, so an image without objects does not divide by zero).
    """
    alpha = 2
    beta = 4
    # Keep both log() terms finite even if predict saturates at exactly 0 or 1.
    predict = predict.clamp(1e-4, 1 - 1e-4)
    positive_mask = (target == 1).float()
    negative_mask = 1 - positive_mask
    positive_loss = torch.pow(1 - predict, alpha) * torch.log(predict) * positive_mask
    # Negatives near a centre (target close to 1) are down-weighted by (1 - target)^beta.
    negative_loss = (
        torch.pow(1 - target, beta)
        * torch.pow(predict, alpha)
        * torch.log(1 - predict)
        * negative_mask
    )
    num_positive = positive_mask.sum().clamp(min=1)
    return -(positive_loss + negative_loss).sum() / num_positive
训练
准备好输入图像input_image和GT热图torch_gt_heatmap后, 就可以看到训练的全过程啦
# Overfit the detector on a single image/heatmap pair to visualise training.
# NOTE(review): input_image, torch_gt_heatmap and plt are not defined in this
# snippet; presumably input_image is a (1, 3, H, W) float tensor,
# torch_gt_heatmap is the (1, 1, H/4, W/4) tensor form of gt_heatmap, and plt
# is matplotlib.pyplot — confirm.
model = Detection()
glr = 1e-2
optim = torch.optim.Adagrad(model.parameters(), glr)
iters = 150
loss_func = focal_loss
for niter in range(iters):
    # The head outputs raw logits, but the focal loss takes log(p) and log(1 - p),
    # so map the logits to probabilities and keep them away from exactly 0 / 1.
    predict = torch.clamp(torch.sigmoid(model(input_image)), 1e-4, 1 - 1e-4)
    loss = loss_func(predict, torch_gt_heatmap)
    optim.zero_grad()
    loss.backward()
    optim.step()
    print(loss.item())
    if niter % 10 == 0:
        plt.imshow(predict[0, 0].detach().numpy())
        plt.show()