一、简介
有关模型的介绍:https://blog.csdn.net/qq_41684249/article/details/115342560?spm=1001.2014.3001.5502
预训练模型下载地址:
链接:https://pan.baidu.com/s/1ngIn3QZYrmKGbqUGiB7BWw
提取码:1mes
二、模型结构
三、PyTorch编程实现
编写MobileFaceNets.py文件。
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
import torch
class Flatten(Module):
    """Collapse every dimension after the first into a single dimension."""

    def forward(self, input):
        """Return ``input`` viewed with shape ``(input.size(0), -1)``.

        Args:
            input: Tensor whose first dimension is kept as-is.

        Returns:
            A 2-D view of ``input``; ``view`` is used, so the result shares
            storage with ``input``.
        """
        # The parameter name shadows the builtin but is kept so that callers
        # passing it by keyword keep working.
        leading = input.size(0)
        return input.view(leading, -1)
class Conv_block(Module):
    """Bias-free 2-D convolution followed by batch normalisation and PReLU.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels; also the number of PReLU slopes.
        kernel: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        groups: Number of convolution groups.
    """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super().__init__()
        # The attribute names below become the state_dict key prefixes
        # ("conv.", "bn.", "prelu."), so they must stay as they are.
        self.conv = Conv2d(
            in_c,
            out_channels=out_c,
            kernel_size=kernel,
            groups=groups,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = BatchNorm2d(out_c)
        self.prelu = PReLU(out_c)

    def forward(self, x):
        """Apply convolution, batch norm and PReLU to ``x``, in that order."""
        return self.prelu(self.bn(self.conv(x)))
class Linear_block(Module):
    """Bias-free 2-D convolution followed by batch normalisation.

    Unlike ``Conv_block`` there is no activation after the batch norm.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        kernel: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        groups: Number of convolution groups.
    """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super().__init__()
        # Attribute names are the state_dict key prefixes ("conv.", "bn.").
        self.conv = Conv2d(
            in_c,
            out_channels=out_c,
            kernel_size=kernel,
            groups=groups,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        """Apply the convolution and then batch normalisation to ``x``."""
        return self.bn(self.conv(x))
class Depth_Wise(Module):
    """Pointwise expand -> depthwise conv -> pointwise linear projection.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels of the projection.
        residual: When true, the block's input is added to its output.
        kernel: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
        groups: Channel count of the intermediate (expanded) tensor; it is
            also used as the group count of the depthwise convolution, so
            every intermediate channel gets its own filter.
    """

    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super().__init__()
        # 1x1 convolution: in_c -> groups channels.
        self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        # Depthwise convolution over the expanded channels.
        self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        # 1x1 convolution without activation: groups -> out_c channels.
        self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual

    def forward(self, x):
        """Run the three stages and optionally add the skip connection.

        With ``residual`` enabled the result is ``x + block(x)``, so the
        block's output must be shape-compatible with its input.
        """
        transformed = self.project(self.conv_dw(self.conv(x)))
        if self.residual:
            return x + transformed
        return transformed
class Residual(Module):
    """A chain of ``num_block`` residual ``Depth_Wise`` blocks.

    Every block maps ``c`` channels to ``c`` channels with its skip
    connection enabled.

    Args:
        c: Channel count going into and out of every block.
        num_block: Number of ``Depth_Wise`` blocks to chain.
        groups: Forwarded to each ``Depth_Wise`` as its ``groups`` argument.
        kernel: Depthwise kernel size of each block.
        stride: Depthwise stride of each block.
        padding: Depthwise padding of each block.
    """

    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super().__init__()
        # The attribute name "model" is part of the state_dict key layout.
        self.model = Sequential(*(
            Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups)
            for _ in range(num_block)
        ))

    def forward(self, x):
        """Pass ``x`` through the blocks in order."""
        return self.model(x)
class MobileFaceNet(Module):
    """MobileFaceNet backbone that maps an image batch to embedding vectors.

    ``conv1`` and the three stride-2 ``Depth_Wise`` stages (``conv_23``,
    ``conv_34``, ``conv_45``) each halve the spatial size, rounding up, so an
    input of height ``H`` reaches ``conv_6_dw`` with height ``ceil(H / 16)``
    (for example 112 -> 7); the same holds for the width.

    Args:
        embedding_size: Length of each output embedding vector.
        out_h: Kernel height of the final depthwise convolution.
        out_w: Kernel width of the final depthwise convolution.

    ``(out_h, out_w)`` has to equal the spatial size of the feature map that
    reaches ``conv_6_dw``; only then is its output 1x1 so that flattening
    yields the 512 features ``linear`` expects.
    """

    def __init__(self, embedding_size, out_h, out_w):
        super().__init__()
        self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
        self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
        self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        # Earlier revisions hard-coded this kernel as (7, 7) or (4, 7); it is
        # now taken from the constructor arguments.
        self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(out_h, out_w), stride=(1, 1), padding=(0, 0))
        self.conv_6_flatten = Flatten()
        self.linear = Linear(512, embedding_size, bias=False)
        self.bn = BatchNorm1d(embedding_size)

    def forward(self, x):
        """Compute embeddings for the image batch ``x``.

        Args:
            x: Input batch with 3 channels (``conv1`` expects 3 input
                channels).

        Returns:
            The batch-normalised output of the final linear layer, with
            ``embedding_size`` features per sample.
        """
        out = self.conv1(x)
        out = self.conv2_dw(out)
        out = self.conv_23(out)
        out = self.conv_3(out)
        out = self.conv_34(out)
        out = self.conv_4(out)
        out = self.conv_45(out)
        out = self.conv_5(out)
        out = self.conv_6_sep(out)
        out = self.conv_6_dw(out)
        out = self.conv_6_flatten(out)
        out = self.linear(out)
        out = self.bn(out)
        return out

    def load_model(self, model, model_path):
        """Load pretrained parameters into ``model`` and wrap it for multi-GPU use.

        The checkpoint must be a mapping with a ``'state_dict'`` entry whose
        keys are the model's own state_dict keys prefixed with ``'backbone.'``.
        Checkpoint entries without a counterpart in ``model`` are ignored.

        Args:
            model: The freshly created model that receives the parameters.
            model_path: Path of the pretrained checkpoint file.

        Returns:
            ``model`` wrapped in ``torch.nn.DataParallel``; it is moved to the
            GPU when CUDA is available and left on the CPU otherwise.

        Raises:
            KeyError: If the checkpoint has no ``'state_dict'`` entry or lacks
                a ``'backbone.'``-prefixed entry for any model parameter.
        """
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be loaded here.
        # map_location='cpu' lets a checkpoint saved on a GPU be read on a
        # host without CUDA; load_state_dict copies into the model afterwards.
        pretrained_dict = torch.load(model_path, map_location='cpu')['state_dict']
        model_dict = model.state_dict()

        missing = [k for k in model_dict if 'backbone.' + k not in pretrained_dict]
        if missing:
            raise KeyError(
                'checkpoint %r has no entry for: %s'
                % (model_path, ', '.join('backbone.' + k for k in missing))
            )

        # Updating the model's own state_dict keeps its metadata intact.
        model_dict.update({k: pretrained_dict['backbone.' + k] for k in model_dict})
        model.load_state_dict(model_dict)

        model = torch.nn.DataParallel(model)
        # Calling .cuda() unconditionally would raise on CPU-only hosts.
        if torch.cuda.is_available():
            model = model.cuda()
        return model
编写测试程序。
if __name__ == '__main__':
    # Create the model: embedding_size=512, out_h=7, out_w=7.
    model = MobileFaceNet(512, 7, 7)
    # Optionally load pretrained parameters (the path is machine-specific):
    # model = model.load_model(model, "/home/malidong/workspace/mobilefacenet/Epoch_17.pt")
    # Print the module tree.
    print(model)