Reference blog: https://blog.csdn.net/weixin_44791964/article/details/106871010
The blog above is a Keras implementation; I wrote this PyTorch version following the same network structure.
The first file implements the 0.25-width MobileNetV1 backbone.
The second file implements the overall RetinaFace network structure.
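Before the code, a quick note on why this backbone is so small: MobileNetV1 replaces every standard 3x3 convolution with a depthwise 3x3 conv followed by a pointwise 1x1 conv, and "0.25" means every layer keeps only a quarter of the usual channel count. A minimal sketch of the parameter savings (the channel numbers here are only illustrative, not taken from the files below):

import torch.nn as nn

# Illustrative only: compare one standard 3x3 conv with the depthwise + pointwise
# pair used by MobileNetV1, at 0.25x width (8 -> 16 channels).
in_c, out_c = 8, 16

standard  = nn.Conv2d(in_c, out_c, 3, padding=1)
depthwise = nn.Conv2d(in_c, in_c, 3, padding=1, groups=in_c)
pointwise = nn.Conv2d(in_c, out_c, 1)

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(standard))                      # 1168 parameters
print(count(depthwise) + count(pointwise))  # 224 parameters, roughly 5x fewer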
1 mobilenet025.py
import torch
import torch.nn as nn
import numpy as np
from torchsummary import summary
import torch.nn.functional as F
class MBVBLOCK(nn.Module):
    """Depthwise-separable convolution block of MobileNetV1."""
    def __init__(self, in_c, out_c, s):
        super().__init__()
        self.mbv = nn.Sequential(
            # Depthwise 3x3 convolution (one filter per input channel)
            nn.Conv2d(in_c, in_c, 3, s, padding=1, groups=in_c),
            nn.BatchNorm2d(in_c),
            nn.ReLU6(inplace=True),
            # Pointwise 1x1 convolution (mixes channels, changes channel count)
            nn.Conv2d(in_c, out_c, 1, 1, padding=0, groups=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU6(inplace=True)
        )

    def forward(self, x):
        x = self.mbv(x)
        # print(x.shape)
        return x
class MobileNet025(nn.Module):
    """MobileNetV1 with a 0.25 width multiplier; returns three feature maps."""
    def __init__(self):
        super().__init__()
        # Stem: standard 3x3 convolution with stride 2
        self.pre = nn.Sequential(
            nn.Conv2d(3, 8, 3, 2, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU6(inplace=True)
        )
        self.feat1 = nn.Sequential(
            MBVBLOCK(8, 16, 2),
            MBVBLOCK(16, 32, 2),
            MBVBLOCK(32, 32, 1),
            MBVBLOCK(32, 64, 2),
            MBVBLOCK(64, 64, 1),
        )
        self.feat2 = nn.Sequential(
            MBVBLOCK(64, 128, 2),
            MBVBLOCK(128, 128, 1),
            MBVBLOCK(128, 128, 1),
            MBVBLOCK(128, 128, 1),
            MBVBLOCK(128, 128, 1),
            MBVBLOCK(128, 128, 1),
        )
        self.feat3 = nn.Sequential(
            MBVBLOCK(128, 256, 2),
            MBVBLOCK(256, 256, 1),
        )

    def forward(self, x):
        x = self.pre(x)
        # print(x.shape)
        f1 = self.feat1(x)   # 64 channels
        f2 = self.feat2(f1)  # 128 channels
        f3 = self.feat3(f2)  # 256 channels
        return [f1, f2, f3]
if __name__ == '__main__':
    net = MobileNet025()
    x = torch.randn(1, 3, 256, 256)
    y = net(x)
    print(y[0].shape, y[1].shape, y[2].shape)
    summary(net, (3, 256, 256))
    # feature-map sizes as a 256x256 input is repeatedly halved: 256 128 64 32 16 8 4
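To see just how light the 0.25x backbone is, a quick parameter count (a sketch; the import path matches the one used in retinaface.py below, adjust it to your project layout):

from nets.mobilenet025 import MobileNet025

net = MobileNet025()
print(sum(p.numel() for p in net.parameters()))                       # total parameters
print(sum(p.numel() for p in net.parameters() if p.requires_grad))   # trainable parameters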
2 retinaface.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.mobilenet025 import MobileNet025
class ConcBatchLRelu(nn.Module):
    """Conv2d -> BatchNorm -> LeakyReLU helper."""
    def __init__(self, ic, oc, k, s, p):
        super().__init__()
        self.f = nn.Sequential(
            nn.Conv2d(ic, oc, k, s, p),
            nn.BatchNorm2d(oc),
            nn.LeakyReLU(0.01, inplace=True)
        )

    def forward(self, x):
        return self.f(x)
class PyramidFeat(nn.Module):
    """FPN-style feature pyramid built on the three backbone outputs."""
    def __init__(self):
        super().__init__()
        # c*: 1x1 lateral convs, f*: 3x3 smoothing convs that produce the outputs,
        # x*: 1x1 convs that reduce channels around the upsampling steps
        self.c1 = ConcBatchLRelu(256, 256, 1, 1, 0)
        self.f1 = ConcBatchLRelu(256, 256, 3, 1, 1)
        self.x1 = ConcBatchLRelu(256, 128, 1, 1, 0)
        self.c2 = ConcBatchLRelu(128, 128, 1, 1, 0)
        self.f2 = ConcBatchLRelu(128, 128, 3, 1, 1)
        self.x2 = ConcBatchLRelu(128, 64, 1, 1, 0)
        self.c3 = ConcBatchLRelu(64, 64, 1, 1, 0)
        self.f3 = ConcBatchLRelu(64, 64, 3, 1, 1)

    def forward(self, x):
        # x = [f1 (64 ch), f2 (128 ch), f3 (256 ch)] from the backbone
        x1 = self.c1(x[2])
        x2 = self.c2(x[1])
        x3 = self.c3(x[0])
        # deepest level: smooth and output directly
        y1 = self.f1(x1)
        # upsample, reduce to 128 channels and merge with the middle level
        x1 = F.interpolate(x1, scale_factor=2, mode='nearest')
        x1 = self.x1(x1)
        y2 = self.f2(x2 + x1)
        # reduce to 64 channels, upsample and merge with the shallowest level
        x2 = self.x2(x2 + x1)
        x1 = F.interpolate(x2, scale_factor=2, mode='nearest')
        y3 = self.f3(x3 + x1)
        return [y1, y2, y3]
class SSH(nn.Module):
    """SSH context module: parallel 3x3 / 5x5 / 7x7 receptive-field branches."""
    def __init__(self, ic):
        super().__init__()
        self.conv1 = ConcBatchLRelu(ic, 32, 3, 1, 1)    # 3x3 branch
        self.conv2 = ConcBatchLRelu(ic, 16, 3, 1, 1)    # shared first conv of the 5x5/7x7 branches
        self.conv2_1 = ConcBatchLRelu(16, 16, 3, 1, 1)  # second conv of the 5x5 branch
        self.conv3_1 = ConcBatchLRelu(16, 16, 3, 1, 1)  # second conv of the 7x7 branch
        self.conv3_2 = ConcBatchLRelu(16, 16, 3, 1, 1)  # third conv of the 7x7 branch

    def forward(self, x):
        x1 = self.conv1(x)
        x_shared = self.conv2(x)
        x2 = self.conv2_1(x_shared)
        x3 = self.conv3_2(self.conv3_1(x_shared))
        # print(x1.shape, x2.shape, x3.shape)
        y = torch.cat((x1, x2, x3), dim=1)  # 32 + 16 + 16 = 64 channels
        return y
class Head(nn.Module):
    """Prediction heads: face/background classification, box regression, 5 facial landmarks."""
    def __init__(self, num_anchors=2, in_channel=64):
        super().__init__()
        self.num_anchors = num_anchors
        self.ClassHead = nn.Conv2d(in_channel, self.num_anchors * 2, 1, 1, 0)
        self.bboxHead = nn.Conv2d(in_channel, self.num_anchors * 4, 1, 1, 0)
        self.landmarkHead = nn.Conv2d(in_channel, self.num_anchors * 5 * 2, 1, 1, 0)

    def forward(self, x):
        b = x.shape[0]
        # move channels last so each anchor's predictions stay contiguous, then flatten per image
        y1 = self.ClassHead(x).permute(0, 2, 3, 1).contiguous().view(b, -1, 2)
        y1 = F.softmax(y1, dim=-1)
        y2 = self.bboxHead(x).permute(0, 2, 3, 1).contiguous().view(b, -1, 4)
        y3 = self.landmarkHead(x).permute(0, 2, 3, 1).contiguous().view(b, -1, 10)
        return [y1, y2, y3]
class RetinafaceNet(nn.Module):
    def __init__(self, backone='mobilenet'):
        super().__init__()
        # submodules are created once here so their parameters are registered and trainable
        self.mnet = MobileNet025()
        self.Pyra = PyramidFeat()
        # PyramidFeat outputs carry 256, 128 and 64 channels respectively
        self.ssh1 = SSH(256)
        self.ssh2 = SSH(128)
        self.ssh3 = SSH(64)
        # every SSH module outputs 32 + 16 + 16 = 64 channels
        self.head1 = Head()
        self.head2 = Head()
        self.head3 = Head()

    def forward(self, x):
        mnet_y = self.mnet(x)
        # mnet torch.Size([1, 64, 16, 16]) torch.Size([1, 128, 8, 8]) torch.Size([1, 256, 4, 4]) for a 256x256 input
        Pfeat = self.Pyra(mnet_y)
        s1_feat = self.ssh1(Pfeat[0])
        s2_feat = self.ssh2(Pfeat[1])
        s3_feat = self.ssh3(Pfeat[2])
        # s1_feat torch.Size([1, 64, 4, 4]) torch.Size([1, 64, 8, 8]) torch.Size([1, 64, 16, 16])
        f1out = self.head1(s1_feat)
        f2out = self.head2(s2_feat)
        f3out = self.head3(s3_feat)
        # concatenate the per-level predictions along the anchor dimension
        output1 = torch.cat((f1out[0], f2out[0], f3out[0]), dim=1)
        output2 = torch.cat((f1out[1], f2out[1], f3out[1]), dim=1)
        output3 = torch.cat((f1out[2], f2out[2], f3out[2]), dim=1)
        output = [output1, output2, output3]
        return output
if __name__ == '__main__':
    # x1 = torch.randn(1, 64, 16, 16)
    # x2 = torch.randn(1, 128, 8, 8)
    # x3 = torch.randn(1, 256, 4, 4)
    # x = [x3, x2, x1]
    # net = PyramidFeat()
    # y = net(x)
    # print(y[0].shape, y[1].shape, y[2].shape)
    #
    # model = SSH(3)
    # x_ = torch.randn(1, 3, 24, 24)
    # y = model(x_)
    # print(y.shape)
    # --------------------------- #
    # torch.Size([1, 32, 24, 24])
    # torch.Size([1, 16, 24, 24])
    # torch.Size([1, 16, 24, 24])
    # torch.Size([1, 64, 24, 24])
    # --------------------------- #
    x = torch.randn(1, 3, 256, 256)
    model = RetinafaceNet()
    print(model)
    y = model(x)
    print(y[0].shape, y[1].shape, y[2].shape)
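As a sanity check on the final outputs (a sketch built on the code above; the import path is an assumption following the same convention as nets.mobilenet025): with two anchors per location, the three head levels for a 256x256 input contribute 2 * (4*4 + 8*8 + 16*16) = 672 predictions, so each concatenated output should have 672 rows per image.

import torch
from nets.retinaface import RetinafaceNet  # assumed path, adjust to your project layout

model = RetinafaceNet()
cls, box, ldm = model(torch.randn(1, 3, 256, 256))
print(cls.shape)  # expected torch.Size([1, 672, 2])
print(box.shape)  # expected torch.Size([1, 672, 4])
print(ldm.shape)  # expected torch.Size([1, 672, 10])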