PyTorch RetinaFace Model Implementation

Reference blog: https://blog.csdn.net/weixin_44791964/article/details/106871010

The blog above is a Keras implementation; I wrote this PyTorch version by following the same network structure.

The first file implements the 0.25-width version of MobileNetV1, used as the backbone.

The second file implements the overall RetinaFace network structure.
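
(A note of my own, not from the referenced post: the "0.25" is MobileNetV1's width multiplier. Every channel count of the full network is scaled by 0.25, which is exactly where the 8/16/32/64/128/256 widths in the code below come from.)

# My own illustration of the width multiplier, not part of the model code:
base_channels = [32, 64, 128, 256, 512, 1024]   # full MobileNetV1 widths
alpha = 0.25                                    # width multiplier
print([int(c * alpha) for c in base_channels])  # [8, 16, 32, 64, 128, 256]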

1 mobilenet025.py

import torch
import torch.nn as nn
from torchsummary import summary

class MBVBLOCK(nn.Module):
    """Depthwise-separable convolution block of MobileNetV1."""
    def __init__(self, in_c, out_c, s):
        super().__init__()

        self.mbv = nn.Sequential(
            # Depthwise: one 3x3 filter per input channel
            nn.Conv2d(in_c, in_c, 3, s, padding=1, groups=in_c),
            nn.BatchNorm2d(in_c),
            nn.ReLU6(inplace=True),

            # Pointwise: 1x1 conv mixes channels and sets the output width
            nn.Conv2d(in_c, out_c, 1, 1, padding=0),
            nn.BatchNorm2d(out_c),
            nn.ReLU6(inplace=True)
        )

    def forward(self, x):
        return self.mbv(x)

class MobileNet025(nn.Module):
    def __init__(self):
        super().__init__()
        self.pre = nn.Sequential(
            nn.Conv2d(3, 8, 3, 2, padding=1),
            nn.BatchNorm2d(8),
            nn.ReLU6(inplace=True)
        )

        self.feat1 = nn.Sequential(
            MBVBLOCK(8, 16, 2),
            MBVBLOCK(16, 32, 2),
            MBVBLOCK(32, 32, 1),
            MBVBLOCK(32, 64, 2),
            MBVBLOCK(64, 64, 1),
        )
        self.feat2 = nn.Sequential(
            MBVBLOCK(64, 128, 2),
            MBVBLOCK(128, 128, 1),
            MBVBLOCK(128, 128, 1),
            MBVBLOCK(128, 128, 1),
            MBVBLOCK(128, 128, 1),
            MBVBLOCK(128, 128, 1),
        )
        self.feat3 = nn.Sequential(
            MBVBLOCK(128, 256, 2),
            MBVBLOCK(256, 256, 1),
        )

    def forward(self, x):
        x = self.pre(x)
        f1 = self.feat1(x)   # for a 256 input: 64 x 16 x 16
        f2 = self.feat2(f1)  # 128 x 8 x 8
        f3 = self.feat3(f2)  # 256 x 4 x 4
        return [f1, f2, f3]

if __name__ == '__main__':
    net = MobileNet025()
    x = torch.randn(1, 3, 256, 256)
    y = net(x)
    print(y[0].shape, y[1].shape, y[2].shape)
    summary(net, (3, 256, 256))
    # Expected: [1, 64, 16, 16], [1, 128, 8, 8], [1, 256, 4, 4]
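
Why split each block into a depthwise plus a pointwise convolution? A back-of-the-envelope weight count (my own arithmetic, not from the blog) for a single 128-to-128 block with 3x3 kernels:

# My own arithmetic, not part of the model code:
in_c, out_c, k = 128, 128, 3
standard  = in_c * out_c * k * k          # 147456 weights for a plain 3x3 conv
separable = in_c * k * k + in_c * out_c   # 1152 (depthwise) + 16384 (pointwise)
print(standard / separable)               # ~8.4x fewer weights

This saving is what keeps the backbone light enough for real-time face detection.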



2 retinaface.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.mobilenet025 import MobileNet025

class ConcBatchLRelu(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU helper block."""
    def __init__(self, ic, oc, k, s, p):
        super().__init__()

        self.f = nn.Sequential(
            nn.Conv2d(ic, oc, k, s, p),
            nn.BatchNorm2d(oc),
            nn.LeakyReLU(0.01, inplace=True)
        )

    def forward(self, x):
        return self.f(x)

class PyramidFeat(nn.Module):
    """FPN-style neck: 1x1 laterals, top-down upsampling, 3x3 smoothing."""
    def __init__(self):
        super().__init__()

        self.c1 = ConcBatchLRelu(256, 256, 1, 1, 0)
        self.f1 = ConcBatchLRelu(256, 256, 3, 1, 1)
        self.x1 = ConcBatchLRelu(256, 128, 1, 1, 0)

        self.c2 = ConcBatchLRelu(128, 128, 1, 1, 0)
        self.f2 = ConcBatchLRelu(128, 128, 3, 1, 1)
        self.x2 = ConcBatchLRelu(128, 64, 1, 1, 0)

        self.c3 = ConcBatchLRelu(64, 64, 1, 1, 0)
        self.f3 = ConcBatchLRelu(64, 64, 3, 1, 1)

    def forward(self, x):
        # x = [f1 (64ch), f2 (128ch), f3 (256ch)] from the backbone
        x1 = self.c1(x[2])
        x2 = self.c2(x[1])
        x3 = self.c3(x[0])

        y1 = self.f1(x1)

        # upsample the deepest level, shrink channels, merge into the next
        up = self.x1(F.interpolate(x1, scale_factor=2, mode='nearest'))
        y2 = self.f2(x2 + up)

        up = self.x2(x2 + up)
        up = F.interpolate(up, scale_factor=2, mode='nearest')
        y3 = self.f3(x3 + up)
        return [y1, y2, y3]
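
# Top-down flow for a 256x256 input (my own shape bookkeeping, not from the post):
#   x[2]: 256ch 4x4  -> y1: 256ch 4x4
#   upsample to 8x8, reduce 256->128, add to x[1]  -> y2: 128ch 8x8
#   reduce 128->64, upsample to 16x16, add to x[0] -> y3: 64ch 16x16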

class SSH(nn.Module):
    """SSH context module: parallel 3x3 / 5x5 / 7x7 branches built from stacked 3x3s."""
    def __init__(self, ic):
        super().__init__()
        self.conv1 = ConcBatchLRelu(ic, 32, 3, 1, 1)

        self.conv2 = ConcBatchLRelu(ic, 16, 3, 1, 1)
        self.conv2_1 = ConcBatchLRelu(16, 16, 3, 1, 1)

        self.conv3_1 = ConcBatchLRelu(16, 16, 3, 1, 1)
        self.conv3_2 = ConcBatchLRelu(16, 16, 3, 1, 1)

    def forward(self, x):
        t = self.conv2(x)                    # shared stem, computed once
        x1 = self.conv1(x)                   # 3x3 branch, 32 channels
        x2 = self.conv2_1(t)                 # two 3x3s ~ 5x5 receptive field, 16 ch
        x3 = self.conv3_2(self.conv3_1(t))   # three 3x3s ~ 7x7 receptive field, 16 ch
        return torch.cat((x1, x2, x3), dim=1)  # 32 + 16 + 16 = 64 channels

class Head(nn.Module):
    """Prediction heads: class (2), box (4) and landmark (5 points x 2) per anchor."""
    def __init__(self, num_anchors=2, in_channel=64):
        super().__init__()
        self.num_anchors = num_anchors
        self.ClassHead = nn.Conv2d(in_channel, self.num_anchors * 2, 1, 1, 0)
        self.bboxHead = nn.Conv2d(in_channel, self.num_anchors * 4, 1, 1, 0)
        self.landmarkHead = nn.Conv2d(in_channel, self.num_anchors * 5 * 2, 1, 1, 0)

    def forward(self, x):
        b = x.shape[0]
        # permute to NHWC before flattening so each row stays one anchor's prediction
        y1 = self.ClassHead(x).permute(0, 2, 3, 1).contiguous().view(b, -1, 2)
        y1 = F.softmax(y1, dim=-1)
        y2 = self.bboxHead(x).permute(0, 2, 3, 1).contiguous().view(b, -1, 4)
        y3 = self.landmarkHead(x).permute(0, 2, 3, 1).contiguous().view(b, -1, 10)
        return [y1, y2, y3]

class RetinafaceNet(nn.Module):
    def __init__(self):
        super().__init__()
        # submodules must be built here, not in forward(); otherwise their weights
        # are re-initialized on every call and never registered for training
        self.mnet = MobileNet025()
        self.Pyra = PyramidFeat()
        # pyramid outputs carry 256 / 128 / 64 channels respectively
        self.ssh1 = SSH(256)
        self.ssh2 = SSH(128)
        self.ssh3 = SSH(64)
        self.head1 = Head()
        self.head2 = Head()
        self.head3 = Head()

    def forward(self, x):
        mnet_y = self.mnet(x)
        # for a 256 input: [1, 64, 16, 16], [1, 128, 8, 8], [1, 256, 4, 4]
        Pfeat = self.Pyra(mnet_y)
        s1_feat = self.ssh1(Pfeat[0])
        s2_feat = self.ssh2(Pfeat[1])
        s3_feat = self.ssh3(Pfeat[2])
        f1out = self.head1(s1_feat)
        f2out = self.head2(s2_feat)
        f3out = self.head3(s3_feat)
        # concatenate the three levels along the anchor dimension
        output1 = torch.cat((f1out[0], f2out[0], f3out[0]), dim=1)
        output2 = torch.cat((f1out[1], f2out[1], f3out[1]), dim=1)
        output3 = torch.cat((f1out[2], f2out[2], f3out[2]), dim=1)
        return [output1, output2, output3]




if __name__ == '__main__':
    x = torch.randn(1, 3, 256, 256)
    model = RetinafaceNet()
    print(model)
    y = model(x)
    print(y[0].shape, y[1].shape, y[2].shape)
    # Expected: [1, 672, 2], [1, 672, 4], [1, 672, 10]
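
A quick sanity check on those output sizes (my own arithmetic; the printed shapes should agree with it): with a 256x256 input the three SSH maps are 4x4, 8x8 and 16x16, each cell carrying 2 anchors.

# My own arithmetic, not part of the model code:
levels = [4, 8, 16]          # SSH feature-map sizes for a 256x256 input
num_anchors = 2              # anchors per cell, as in Head()
total = sum(s * s * num_anchors for s in levels)
print(total)                 # 672 -> (1, 672, 2), (1, 672, 4), (1, 672, 10)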




 
