**mtcnn 基础网络架构**
```python
import torch
import torch.nn as nn
from collections import OrderedDict
import numpy as np
class Flatten(nn.Module):
    """Flatten a 4-D feature map into one feature vector per sample.

    The last two axes are swapped before flattening, so elements are laid out
    in (channels, width, height) order rather than PyTorch's usual
    (channels, height, width).

    NOTE(review): presumably the swap matches the layout the pre-trained
    fully connected weights were exported with -- confirm before removing it.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """Flatten ``x``.

        Args:
            x: tensor of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, channels * height * width).
        """
        # ``transpose`` returns a non-contiguous view and ``view`` requires
        # contiguous memory, hence the ``contiguous()`` copy.
        x = x.transpose(3, 2).contiguous()
        return x.view(x.size(0), -1)


class Pnet(nn.Module):
    """MTCNN proposal network (P-Net).

    The network is fully convolutional, so it accepts any input spatial size
    large enough to survive the conv/pool stack (a 12x12 input yields a 1x1
    output map).

    Args:
        weights_path: optional path to a ``.npy`` file containing a pickled
            dict that maps every ``named_parameters()`` key (for example
            ``'features.conv1.weight'``) to a NumPy array. When ``None``
            (the default) the layers keep PyTorch's default initialisation.
    """

    def __init__(self, weights_path=None):
        super(Pnet, self).__init__()
        # Layer/attribute names must match the keys in the weights file and
        # the naming used by RNet/ONet ("features", "preluN").
        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 10, 3, 1)),
            ('prelu1', nn.PReLU(10)),
            ('pool1', nn.MaxPool2d(2, 2, ceil_mode=True)),
            ('conv2', nn.Conv2d(10, 16, 3, 1)),
            ('prelu2', nn.PReLU(16)),
            ('conv3', nn.Conv2d(16, 32, 3, 1)),
            ('prelu3', nn.PReLU(32)),
        ]))
        self.conv4_1 = nn.Conv2d(32, 2, 1, 1)  # 2-channel face-score head
        self.conv4_2 = nn.Conv2d(32, 4, 1, 1)  # 4-channel box head
        if weights_path is not None:
            # The file stores a dict inside a 0-d object array; ``[()]``
            # unwraps it. NOTE(review): allow_pickle=True unpickles the file,
            # so only load trusted weight files.
            weights = np.load(weights_path, allow_pickle=True)[()]
            for name, param in self.named_parameters():
                param.data = torch.FloatTensor(weights[name])

    @property
    def feature(self):
        """Backward-compatible alias for ``features`` (the original attribute name)."""
        return self.features

    def forward(self, x):
        """Run P-Net on a batch of images.

        Args:
            x: float tensor of shape (batch, 3, height, width).

        Returns:
            Tuple ``(xy_exit, face_exit)``:
              - ``xy_exit``: (batch, 4, H', W') output of the 4-channel head
                (presumably bounding-box regression offsets).
              - ``face_exit``: (batch, 2, H', W') scores, softmax-normalised
                over the 2 channels at every spatial position.
        """
        x = self.features(x)
        face_exit = self.conv4_1(x)
        xy_exit = self.conv4_2(x)
        # Normalise across the 2 score channels (dim=1), not the width axis.
        face_exit = nn.functional.softmax(face_exit, dim=1)
        return xy_exit, face_exit

    # Backward-compatible alias for the original misspelled method name.
    forword = forward


class RNet(nn.Module):
    """MTCNN refinement network (R-Net).

    The ``Linear(576, 128)`` layer fixes the input size: 576 = 64 * 3 * 3,
    which is what the conv/pool stack produces from 24x24 inputs.

    Args:
        weights_path: optional path to a ``.npy`` file containing a pickled
            dict that maps every ``named_parameters()`` key (for example
            ``'features.conv1.weight'``) to a NumPy array. When ``None``
            (the default) the layers keep PyTorch's default initialisation.
    """

    def __init__(self, weights_path=None):
        super(RNet, self).__init__()
        # Layer names must match the keys stored in the weights file.
        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 28, 3, 1)),
            ('prelu1', nn.PReLU(28)),
            ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
            ('conv2', nn.Conv2d(28, 48, 3, 1)),
            ('prelu2', nn.PReLU(48)),
            ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
            ('conv3', nn.Conv2d(48, 64, 2, 1)),
            ('prelu3', nn.PReLU(64)),
            ('flatten', Flatten()),
            ('conv4', nn.Linear(576, 128)),
            ('prelu4', nn.PReLU(128)),
        ]))
        self.conv5_1 = nn.Linear(128, 2)  # 2-way face-score head
        self.conv5_2 = nn.Linear(128, 4)  # 4-value box head
        if weights_path is not None:
            # The file stores a dict inside a 0-d object array; ``[()]``
            # unwraps it. NOTE(review): allow_pickle=True unpickles the file,
            # so only load trusted weight files.
            weights = np.load(weights_path, allow_pickle=True)[()]
            for name, param in self.named_parameters():
                param.data = torch.FloatTensor(weights[name])

    def forward(self, x):
        """Run R-Net on a batch of image crops.

        Args:
            x: float tensor of shape (batch, 3, 24, 24).

        Returns:
            Tuple ``(xy_exit, face_exit)``:
              - ``xy_exit``: (batch, 4) output of the box head
                (presumably bounding-box regression offsets).
              - ``face_exit``: (batch, 2) softmax-normalised scores.
        """
        x = self.features(x)
        face_exit = self.conv5_1(x)
        xy_exit = self.conv5_2(x)
        face_exit = nn.functional.softmax(face_exit, dim=-1)
        return xy_exit, face_exit


class ONet(nn.Module):
    """MTCNN output network (O-Net).

    The ``Linear(1152, 256)`` layer fixes the input size: 1152 = 128 * 3 * 3,
    which is what the conv/pool stack produces from 48x48 inputs.

    Args:
        weights_path: optional path to a ``.npy`` file containing a pickled
            dict that maps every ``named_parameters()`` key (for example
            ``'features.conv1.weight'``) to a NumPy array. When ``None``
            (the default) the layers keep PyTorch's default initialisation.
    """

    def __init__(self, weights_path=None):
        super(ONet, self).__init__()
        # Layer names must match the keys stored in the weights file.
        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(3, 32, 3, 1)),
            ('prelu1', nn.PReLU(32)),
            ('pool1', nn.MaxPool2d(3, 2, ceil_mode=True)),
            ('conv2', nn.Conv2d(32, 64, 3, 1)),
            ('prelu2', nn.PReLU(64)),
            ('pool2', nn.MaxPool2d(3, 2, ceil_mode=True)),
            ('conv3', nn.Conv2d(64, 64, 3, 1)),
            ('prelu3', nn.PReLU(64)),
            ('pool3', nn.MaxPool2d(2, 2, ceil_mode=True)),
            ('conv4', nn.Conv2d(64, 128, 2, 1)),
            ('prelu4', nn.PReLU(128)),
            ('flatten', Flatten()),
            ('conv5', nn.Linear(1152, 256)),
            ('drop5', nn.Dropout(0.25)),
            ('prelu5', nn.PReLU(256)),
        ]))
        self.conv6_1 = nn.Linear(256, 2)   # 2-way face-score head
        self.conv6_2 = nn.Linear(256, 4)   # 4-value box head
        self.conv6_3 = nn.Linear(256, 10)  # 10-value head (presumably 5 (x, y) landmarks)
        if weights_path is not None:
            # The file stores a dict inside a 0-d object array; ``[()]``
            # unwraps it. NOTE(review): allow_pickle=True unpickles the file,
            # so only load trusted weight files.
            weights = np.load(weights_path, allow_pickle=True)[()]
            for name, param in self.named_parameters():
                param.data = torch.FloatTensor(weights[name])

    def forward(self, x):
        """Run O-Net on a batch of image crops.

        Args:
            x: float tensor of shape (batch, 3, 48, 48).

        Returns:
            Tuple ``(dot_exit, xy_exit, face_exit)``:
              - ``dot_exit``: (batch, 10) output of the 10-value head
                (presumably facial landmark coordinates).
              - ``xy_exit``: (batch, 4) output of the box head
                (presumably bounding-box regression offsets).
              - ``face_exit``: (batch, 2) softmax-normalised scores.
        """
        x = self.features(x)
        face_exit = self.conv6_1(x)
        xy_exit = self.conv6_2(x)
        dot_exit = self.conv6_3(x)
        face_exit = nn.functional.softmax(face_exit, dim=-1)
        return dot_exit, xy_exit, face_exit
```

mtcnn 基础网络架构
最新推荐文章于 2023-03-19 15:28:37 发布