MTCNN训练与测试（PyTorch实现）

训练代码:

import os
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from Getdatas import getdatas
class Trainer:
    """Train one MTCNN stage network with confidence, box-offset and landmark losses.

    The training loop in ``train`` never terminates on its own; after every full
    pass over the dataset the network's ``state_dict`` is written to
    ``save_para_path``.
    """

    def __init__(self, net, save_para_path, dataset_path, iscuda=True):
        """
        Args:
            net: network to train; its forward pass must return a 3-tuple
                ``(confidence, box_offset, landmark_offset)``.
            save_para_path: file the ``state_dict`` is saved to after every epoch.
                If it already exists, its weights are loaded so training resumes.
            dataset_path: path passed to ``getdatas`` to build the dataset.
            iscuda: move the network and every batch to the GPU.
        """
        self.net = net
        self.save_para_path = save_para_path
        self.data_path = dataset_path
        self.iscuda = iscuda
        if self.iscuda:
            self.net.cuda()
        self.cond_lossfunc = nn.BCELoss()
        self.offset_lossfunc = nn.MSELoss()
        self.opt = optim.Adam(params=net.parameters(), lr=0.001)
        # Resume from previously saved weights if present. train() saves the
        # state_dict of the (possibly CUDA) model, so map the tensors onto the
        # device actually in use; otherwise a CPU-only run could not load them.
        if os.path.exists(self.save_para_path):
            device = "cuda" if self.iscuda else "cpu"
            net.load_state_dict(torch.load(self.save_para_path, map_location=device))

    @staticmethod
    def _loss_or_zero(loss_fn, output, target):
        """Return ``loss_fn(output, target)``, or a zero loss when nothing was selected.

        A mean-reduced loss over an empty selection evaluates to NaN, which would
        turn the whole batch loss (and its printout) into NaN.
        """
        if target.numel() == 0:
            # Multiply by zero instead of creating a new constant so the result keeps
            # the output's device/dtype and stays attached to the autograd graph.
            return output.sum() * 0.0
        return loss_fn(output, target)

    def train(self):
        """Train forever, saving the weights after each epoch.

        Label semantics, as used by the masks below: samples whose ``cond_`` value
        is < 2 contribute to the confidence loss, and samples whose value is > 0
        contribute to the box and landmark offset losses. (Presumably
        0 = negative, 1 = positive, 2 = part face; confirm against ``getdatas``.)
        """
        facedataset = getdatas(self.data_path)
        dataloader = DataLoader(facedataset, batch_size=512, shuffle=True, num_workers=4)
        count = 0
        while True:
            for img_data_, cond_, position_offset_, landmark_offset_ in dataloader:
                if self.iscuda:
                    img_data_ = img_data_.cuda()
                    cond_ = cond_.cuda()
                    position_offset_ = position_offset_.cuda()
                    landmark_offset_ = landmark_offset_.cuda()
                cond_output_, position_offset_output_, landmark_offset_output_ = self.net(img_data_)

                # Confidence loss: only samples with label < 2 take part.
                cond_mask = torch.lt(cond_, 2)
                cond = torch.masked_select(cond_, cond_mask)
                cond_output = torch.masked_select(cond_output_.reshape(-1, 1), cond_mask)
                cond_loss = self._loss_or_zero(self.cond_lossfunc, cond_output, cond)

                # Offset losses: samples with label <= 0 are excluded. The same row
                # mask selects samples for both the box and the landmark regression.
                offset_rows = torch.gt(cond_, 0)[:, 0]

                position_offset = position_offset_[offset_rows]
                position_offset_output = position_offset_output_[offset_rows].reshape(-1, 4)
                position_offset_loss = self._loss_or_zero(
                    self.offset_lossfunc, position_offset_output, position_offset
                )

                landmark_offset = landmark_offset_[offset_rows]
                landmark_offset_output = landmark_offset_output_[offset_rows].reshape(-1, 10)
                landmark_offset_loss = self._loss_or_zero(
                    self.offset_lossfunc, landmark_offset_output, landmark_offset
                )

                loss = cond_loss + position_offset_loss + landmark_offset_loss
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                print("loss:", loss.item(), "cond_loss:", cond_loss.item(),
                      "position_offset_loss", position_offset_loss.item(),
                      "landmark_offset_loss", landmark_offset_loss.item())
            count = count + 1
            print("第{0}轮训练完毕".format(count))
            torch.save(self.net.state_dict(), self.save_para_path)
            print("保存成功")

测试代码:

import torch
import numpy as np
from tools import nms,convert_to_square
import Nets
from torchvision import transforms
class P_net_detector:
    """First MTCNN stage: run P-net over an image pyramid and collect candidate boxes."""

    def __init__(self, P_net_param=r"P_net.pt", isCuda=True):
        """
        Args:
            P_net_param: path of the saved P-net ``state_dict``.
            isCuda: run the network on the GPU.
        """
        self.isCuda = isCuda
        self.P_net = Nets.P_net()
        if self.isCuda:
            self.P_net.cuda()
        # map_location lets weights trained on a GPU load when isCuda=False.
        device = "cuda" if self.isCuda else "cpu"
        self.P_net.load_state_dict(torch.load(P_net_param, map_location=device))
        self.P_net.eval()
        self.m = transforms.ToTensor()

    def detect(self, image):
        """Run P-net on every level of an image pyramid.

        The image is repeatedly shrunk by a factor of 0.7 until its shorter side
        drops below 12 pixels. At each level, cells with confidence > 0.6 are
        turned into boxes and filtered with ``nms`` at threshold 0.5.

        Args:
            image: PIL image.

        Returns:
            numpy array built from the kept boxes, whose rows are produced by
            ``return_box`` (x1, y1, x2, y2, confidence) in original-image
            coordinates. NMS is applied per pyramid level only; boxes from
            different levels are not suppressed against each other here.
        """
        w, h = image.size
        min_side = min(w, h)  # the pyramid stops once the shorter side is < 12
        scale = 1  # current scale relative to the original image
        alpha = 0.7  # shrink factor between pyramid levels
        boxes_ = []
        # Inference only: no autograd graph is needed, even if the caller did not
        # enable torch.no_grad() itself.
        with torch.no_grad():
            while min_side >= 12:
                img_data = self.m(image)
                if self.isCuda:
                    img_data = img_data.cuda()
                img_data.unsqueeze_(0)  # add the batch dimension
                cond_, position_offset_, landmark_offset_ = self.P_net(img_data)
                _cond = cond_.cpu().detach()
                _position_offset = position_offset_.cpu().detach()
                boxes = nms(self.return_box(_cond, _position_offset, 0.6, scale), 0.5)
                boxes_.extend(boxes)
                scale *= alpha
                _w = int(w * scale)
                _h = int(h * scale)
                image = image.resize((_w, _h))
                min_side = min(_w, _h)
        return np.array(boxes_)

    def return_box(self, cond, offset, c, scale):
        """Map feature-map cells whose confidence exceeds ``c`` back to image boxes.

        Args:
            cond: P-net confidence map; the selected cells are taken from
                ``cond[0][0]`` (first image, first channel).
            offset: P-net box-offset map; channels 0-3 hold the x1, y1, x2, y2
                offsets (in units of the 12-pixel window).
            c: confidence threshold.
            scale: pyramid scale of the current level.

        Returns:
            Tensor of shape (K, 5): x1, y1, x2, y2, confidence, in original-image
            coordinates.
        """
        mask = torch.gt(cond[:, 0], c)
        # (row, col) of every selected cell. torch.nonzero and boolean-mask
        # indexing both enumerate in row-major order, so the rows line up with
        # the offsets and scores selected through `mask` below.
        idx = torch.nonzero(torch.gt(cond[0][0], c)).float()
        col = idx[:, 1:]
        row = idx[:, :1]
        x1, y1, x2, y2 = (offset[:, k][mask].view(-1, 1).float() for k in range(4))
        score = cond[:, 0][mask].view(-1, 1)
        # Each cell maps to a 12x12 window at stride 2 on the current level;
        # dividing by `scale` returns to original-image coordinates.
        return torch.cat(
            (
                (col * 2 + x1 * 12) / scale,
                (row * 2 + y1 * 12) / scale,
                (col * 2 + x2 * 12 + 12) / scale,
                (row * 2 + y2 * 12 + 12) / scale,
                score,
            ),
            1,
        )
class R_net_detector:
    """Second MTCNN stage: re-score and refine P-net candidate boxes with R-net."""

    def __init__(self, R_net_param=r"R_net.pt", isCuda=True):
        """
        Args:
            R_net_param: path of the saved R-net ``state_dict``.
            isCuda: run the network on the GPU.
        """
        self.isCuda = isCuda
        self.R_net = Nets.R_net()
        if self.isCuda:
            self.R_net.cuda()
        # map_location lets weights trained on a GPU load when isCuda=False.
        device = "cuda" if self.isCuda else "cpu"
        self.R_net.load_state_dict(torch.load(R_net_param, map_location=device))
        self.R_net.eval()
        self.m = transforms.ToTensor()

    def detect(self, image, P_net_boxes):
        """Refine P-net boxes with R-net.

        Each candidate is squared with ``convert_to_square``, cropped from
        ``image`` and resized to 24x24. Crops with R-net confidence > 0.6 have
        their boxes corrected by the predicted offsets (scaled by the square's
        width/height) and are then filtered with ``nms`` at threshold 0.5.

        Args:
            image: PIL image the boxes refer to.
            P_net_boxes: candidate boxes; columns 0-3 are x1, y1, x2, y2.

        Returns:
            numpy array of the kept boxes (x1, y1, x2, y2, confidence), or an empty
            array when ``P_net_boxes`` is empty.
        """
        if len(P_net_boxes) == 0:
            return np.array([])
        # Square up the P-net boxes, then crop those squares from the image.
        _P_net_boxes = convert_to_square(P_net_boxes)
        crops = []
        for _box in _P_net_boxes:
            _x1 = int(_box[0])
            _y1 = int(_box[1])
            _x2 = int(_box[2])
            _y2 = int(_box[3])
            img = image.crop((_x1, _y1, _x2, _y2)).resize((24, 24))
            crops.append(self.m(img))
        img_dataset = torch.stack(crops)
        if self.isCuda:
            img_dataset = img_dataset.cuda()
        with torch.no_grad():
            cond_, position_offset_, landmark_offset_ = self.R_net(img_dataset)
        cond = cond_.cpu().detach()
        position_offset = position_offset_.cpu().detach()

        keep = torch.gt(cond[:, 0], 0.6)  # computed once, reused for every column
        proposals = torch.tensor(_P_net_boxes)
        _x1, _y1, _x2, _y2 = (proposals[:, k][keep].view(-1, 1).float() for k in range(4))
        x1, y1, x2, y2 = (position_offset[:, k][keep].view(-1, 1).float() for k in range(4))
        box_w = _x2 - _x1
        box_h = _y2 - _y1
        real_box = torch.cat(
            (
                _x1 + box_w * x1,
                _y1 + box_h * y1,
                _x2 + box_w * x2,
                _y2 + box_h * y2,
                cond[:, 0][keep].view(-1, 1),
            ),
            1,
        )
        boxes_ = list(nms(real_box, 0.5))
        return np.array(boxes_)
class O_net_detector:
    """Third MTCNN stage: final box refinement plus five landmark points with O-net."""

    def __init__(self, O_net_param=r"O_net.pt", isCuda=True):
        """
        Args:
            O_net_param: path of the saved O-net ``state_dict``.
            isCuda: run the network on the GPU.
        """
        self.isCuda = isCuda
        self.O_net = Nets.O_net()
        if self.isCuda:
            self.O_net.cuda()
        # map_location lets weights trained on a GPU load when isCuda=False.
        device = "cuda" if self.isCuda else "cpu"
        self.O_net.load_state_dict(torch.load(O_net_param, map_location=device))
        self.O_net.eval()
        self.m = transforms.ToTensor()

    def detect(self, image, R_net_boxes):
        """Refine R-net boxes with O-net and decode the five landmarks.

        Each candidate is squared with ``convert_to_square``, cropped from
        ``image`` and resized to 48x48. Crops with O-net confidence >= 0.99 have
        their box corrected by the predicted offsets. Landmarks are decoded
        relative to the square's top-left corner. A candidate is kept only if its
        third landmark lies strictly inside the refined box. Survivors are
        filtered with ``nms`` at threshold 0.1.

        Args:
            image: PIL image the boxes refer to.
            R_net_boxes: candidate boxes; columns 0-3 are x1, y1, x2, y2.

        Returns:
            Result of ``nms`` over rows
            ``[x1, y1, x2, y2, confidence, fx1, fy1, ..., fx5, fy5]``, or an empty
            numpy array when there is no input or no candidate survives.
        """
        if len(R_net_boxes) == 0:
            return np.array([])
        # Square up the R-net boxes, then crop those squares from the image.
        _R_net_boxes = convert_to_square(R_net_boxes)
        crops = []
        for _box in _R_net_boxes:
            _x1 = int(_box[0])
            _y1 = int(_box[1])
            _x2 = int(_box[2])
            _y2 = int(_box[3])
            img = image.crop((_x1, _y1, _x2, _y2)).resize((48, 48))
            crops.append(self.m(img))
        img_dataset = torch.stack(crops)
        if self.isCuda:
            img_dataset = img_dataset.cuda()
        with torch.no_grad():
            cond_, position_offset_, landmark_offset_ = self.O_net(img_dataset)
        cond = cond_.cpu().detach()
        position_offset = position_offset_.cpu().detach()
        landmark_offset = landmark_offset_.cpu().detach()

        boxes = []
        indices, _ = np.where(cond.numpy() >= 0.99)  # second value (column) unused
        for index in indices:
            # The squared R-net box is the reference frame for O-net's offsets.
            _box = _R_net_boxes[index]
            _x1 = int(_box[0])
            _y1 = int(_box[1])
            _x2 = int(_box[2])
            _y2 = int(_box[3])
            ow = _x2 - _x1
            oh = _y2 - _y1
            pos = position_offset[index]
            lm = landmark_offset[index]
            x1 = float(_x1 + ow * pos[0])
            y1 = float(_y1 + oh * pos[1])
            x2 = float(_x2 + ow * pos[2])
            y2 = float(_y2 + oh * pos[3])
            # Landmarks are stored as (x, y) pairs, offsets from the top-left corner.
            landmarks = []
            for k in range(5):
                landmarks.append(float(_x1 + ow * lm[2 * k]))
                landmarks.append(float(_y1 + oh * lm[2 * k + 1]))
            fx3, fy3 = landmarks[4], landmarks[5]
            # Sanity filter: the third landmark (presumably the nose in the usual
            # 5-point order) must lie inside the refined box. Plain comparisons are
            # used so boxes extending past the left/top edge (negative x1/y1) are
            # not wrongly rejected.
            if x1 < fx3 < x2 and y1 < fy3 < y2:
                boxes.append([x1, y1, x2, y2, float(cond[index][0]), *landmarks])
        if not boxes:
            return np.array([])
        return nms(np.array(boxes), 0.1)
if __name__ == "__main__":
    # Run the full P -> R -> O cascade on one image, print per-stage timings and
    # display the detected faces with their five landmark points.
    import time

    # `pimg` / `ImageDraw` are Pillow's Image and ImageDraw modules. The
    # detectors above already operate on Pillow images; the original script used
    # these names without importing them.
    from PIL import Image as pimg
    from PIL import ImageDraw

    image_file = r"test03.jpg"
    # The detector classes are defined in this file, not in separate modules.
    P_detect = P_net_detector()
    R_detect = R_net_detector()
    O_detect = O_net_detector()
    with pimg.open(image_file) as img:
        torch.cuda.empty_cache()
        with torch.no_grad():
            img = img.convert("RGB")

            starttime = time.time()
            P_boxes = P_detect.detect(img)
            torch.cuda.empty_cache()
            P_time = time.time() - starttime

            starttime = time.time()
            R_boxes = R_detect.detect(img, P_boxes)
            torch.cuda.empty_cache()
            R_time = time.time() - starttime

            starttime = time.time()
            O_boxes = O_detect.detect(img, R_boxes)
            torch.cuda.empty_cache()
            O_time = time.time() - starttime

            sumtime = P_time + R_time + O_time
            print("总时间:{0} P时间:{1} R时间:{2} O时间:{3}".format(sumtime, P_time, R_time, O_time))

            imgdraw = ImageDraw.Draw(img)
            for box in O_boxes:
                # Row layout: x1, y1, x2, y2, confidence, then five (x, y) landmarks.
                x1, y1, x2, y2 = (int(v) for v in box[:4])
                landmark_points = [(int(box[5 + 2 * k]), int(box[6 + 2 * k])) for k in range(5)]
                print(box[4])
                imgdraw.rectangle((x1, y1, x2, y2), fill=None, outline="red", width=1)
                i = 1  # side length (in pixels) of each landmark marker
                for fx, fy in landmark_points:
                    imgdraw.rectangle((fx, fy, fx + i, fy + i), fill="#FFFF00", outline="#FFFF00")
            img.show()

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值