基于深度学习的图像视频人像分割背景替换系统

本文详细介绍了基于深度学习的图像和视频人像分割背景替换系统,探讨了其研究背景与意义,包括在影视制作、虚拟现实和视频通信等领域的广泛应用。文章介绍了模型结构,如MobileNetV2结合通道注意力模块(CBAM)的肖像分割模型,以及系统中的人像抠图和证件照生成模块。通过训练和优化,该系统在保持语义分割鲁棒性的同时提高了推理速度,为实际应用提供了更高可用性。
摘要由CSDN通过智能技术生成

1.研究背景与意义

项目参考AAAI Association for the Advancement of Artificial Intelligence

研究背景与意义

近年来,随着深度学习技术的快速发展,图像和视频处理领域取得了巨大的进展。其中,图像和视频人像分割背景替换系统是一个备受关注的研究方向。人像分割是指将图像或视频中的人物与背景进行有效的分离,而背景替换则是指将原始图像或视频中的背景替换为新的背景,从而创造出具有不同环境和场景的视觉效果。

传统的人像分割方法通常基于图像处理技术,如颜色分割、边缘检测和区域生长等。然而,这些方法往往需要手动选择特征和参数,且对于复杂的场景和图像质量较差的情况下效果不佳。而基于深度学习的人像分割方法通过使用深度神经网络模型,可以自动学习图像和视频中的人物和背景之间的复杂关系,从而实现更准确和鲁棒的人像分割效果。

图像和视频人像分割背景替换系统具有广泛的应用前景。首先,它可以应用于电影、电视剧和广告等影视制作领域。通过使用该系统,制片人可以轻松地将演员放置在不同的场景中,从而实现更加逼真和引人注目的视觉效果。其次,该系统还可以应用于虚拟现实和增强现实技术中。通过将真实世界中的人物与虚拟场景相结合,可以创造出更加沉浸式和逼真的虚拟体验。此外,该系统还可以应用于视频会议和视频通话等领域,通过实时的人像分割和背景替换,可以为用户提供更加个性化和有趣的视频通信体验。

然而,目前基于深度学习的图像和视频人像分割背景替换系统仍然存在一些挑战和问题。首先,由于深度学习模型的复杂性和计算需求较高,系统的实时性和实时性仍然是一个挑战。其次,对于复杂的场景和图像质量较差的情况下,现有的人像分割算法仍然存在一定的误差和不准确性。此外,由于缺乏大规模的标注数据集,深度学习模型的泛化能力和鲁棒性仍然有待提高。

因此,进一步研究和改进基于深度学习的图像和视频人像分割背景替换系统具有重要的意义。通过提高算法的准确性和实时性,可以为影视制作、虚拟现实和视频通信等领域提供更加高质量和个性化的解决方案。同时,通过构建更大规模的标注数据集和改进深度学习模型,可以提高系统的泛化能力和鲁棒性,从而推动该领域的进一步发展。

2.图片演示

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.视频演示

基于深度学习的图像视频人像分割背景替换系统_哔哩哔哩_bilibili

4.系统简介

在图像合成或电影制作中,经常会有对目标物体的背景进行替换的需求。人像抠图是自然图像抠图领域中的子任务,该任务除了需要预测复杂图片中人物前景的不透明度外,还需要让预测达到发丝级的精度。在现有的大多数人像抠图算法中,遮罩的预测需要额外的三元图作为输入,并且具有较慢的推理速度,这使得算法难以在实时性应用中得到使用。目前,MODNet是一种无需额外输入,并兼具快速推理和精确预测能力的深度学习人像抠图模型。在 MODNet基础上,本文设计了一种具有创新结构的无需三元图的人像抠图模型,并进一步实现了证件照生成系统。
人像抠图模型基于卷积神经网络实现,并以MobileNetV2骨干网络为编码器,以语义分支、细节分支和融合分支三个分支网络将为解码器,对输入图片进行遮罩预测。模型创新地在数据集制作中采用了数据增强和数据融合方案,并且在语义分支中引入了通道注意力模块CBAM(Convolutional Block Attention Modulc,CBAM)。实验结果表明,这些优化工作使模型在保持语义分割鲁棒性的同时,获得了很好的细节预测能力。
人像抠图模型完成遮罩预测后,证件照裁剪模块经过人脸定位、证件照生成和证件照美化,实现了证件照的端到端生成。随后,证件照生成系统在平台服务器上得到部署和使用。系统性能测试结果表明,模型虽然没有杰出的细节预测性能,但在应用层面上,模型无需额外输入,并且比MODNet抠图模型有约14%的推理速度提升和更好的语义分割性能,从而使证件照生成系统获得了更高的可用性。
在这里插入图片描述

5.核心代码讲解

5.1 bg_replace.py

下面是我封装的类,包含了代码中最核心的部分:


class BackgroundReplacer:
    def __init__(self, config_path, input_shape, img_path=None, video_path='./data/1.mp4', bg_img_path='./data/2.png', bg_video_path=None, save_dir='./output', use_optic_flow=False, soft_predict=True, add_argmax=False, test_speed=False):
        self.config_path = config_path
        self.input_shape = input_shape
        self.img_path = img_path
        self.video_path = video_path
        self.bg_img_path = bg_img_path
        self.bg_video_path = bg_video_path
        self.save_dir = save_dir
        self.use_optic_flow = use_optic_flow
        self.soft_predict = soft_predict
        self.add_argmax = add_argmax
        self.test_speed = test_speed

    def replace_background(self):
        args = self._parse_args()
        env_info = get_sys_env()
        args.use_gpu = True if env_info['Paddle compiled with cuda'] and env_info['GPUs used'] else False
        predictor = Predictor(args)

        if not osp.exists(args.save_dir):
            os.makedirs(args.save_dir)

        # 图像背景替换
        if args.img_path is not None:
            if not osp.exists(args.img_path):
                raise Exception('The --img_path is not existed: {}'.format(args.img_path))
            img = cv2.imread(args.img_path)
            bg = self._get_bg_img(args.bg_img_path, img.shape)

            comb = predictor.run(img, bg)

            save_name = osp.basename(args.img_path)
            save_path = osp.join(args.save_dir, save_name)
            cv2.imwrite(save_path, comb)
        # 视频背景替换
        else:
            # 获取背景:如果提供背景视频则以背景视频作为背景,否则采用提供的背景图片
            if args.bg_video_path is not None:
                if not osp.exists(args.bg_video_path):
                    raise Exception('The --bg_video_path is not existed: {}'.format(args.bg_video_path))
                is_video_bg = True
            else:
                bg = self._get_bg_img(args.bg_img_path, args.input_shape)
                is_video_bg = False

            # 视频预测
            if args.video_path is not None:
                logger.info('Please wait. It is computing......')
                if not osp.exists(args.video_path):
                    raise Exception('The --video_path is not existed: {}'.format(args.video_path))

                cap_video = cv2.VideoCapture(args.video_path)
                fps = cap_video.get(cv2.CAP_PROP_FPS)
                width = int(cap_video.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(cap_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
                save_name = osp.basename(args.video_path)
                save_name = save_name.split('.')[0]
                save_path = osp.join(args.save_dir, save_name + '.avi')

                cap_out = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, (width, height))

                if is_video_bg:
                    cap_bg = cv2.VideoCapture(args.bg_video_path)
                    frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
                    current_bg = 1
                frame_num = 0
                while cap_video.isOpened():
                    ret, frame = cap_video.read()
                    if ret:
                        #读取背景帧
                        if is_video_bg:
                            ret_bg, bg = cap_bg.read()
                            if ret_bg:
                                if current_bg == frames_bg:
                                    current_bg = 1
                                    cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
                            else:
                                break
                            current_bg += 1

                        comb = predictor.run(frame, bg)

                        cap_out.write(comb)
                        frame_num += 1
                        logger.info('Processing frame {}'.format(frame_num))
                    else:
                        break

                if is_video_bg:
                    cap_bg.release()
                cap_video.release()
                cap_out.release()

            # 当没有输入预测图像和视频的时候,则打开摄像头
            else:
                cap_video = cv2.VideoCapture(0)
                if not cap_video.isOpened():
                    raise IOError("Error opening video stream or file, "
                                  "--video_path whether existing: {}"
                                  " or camera whether working".format(args.video_path))
                    return

                if is_video_bg:
                    cap_bg = cv2.VideoCapture(args.bg_video_path)
                    frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
                    current_bg = 1

                while cap_video.isOpened():
                    ret, frame = cap_video.read()
                    if ret:
                        #读取背景帧
                        if is_video_bg:
                            ret_bg, bg = cap_bg.read()
                            if ret_bg:
                                if current_bg == frames_bg:
                                    current_bg = 1
                                    cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
                            else:
                                break
                            current_bg += 1

                        comb = predictor.run(frame, bg)

                        cv2.imshow('HumanSegmentation', comb)
                        if cv2.waitKey(1) & 0xFF == ord('q'):
                            break
                    else:
                        break
                if is_video_bg:
                    cap_bg.release()
                cap_video.release()
        if args.test_speed:
            timer = predictor.cost_averager
            logger.info(
                'Model inference time per image: {}\nFPS: {}\nNum of images: {}'.format(timer.get_average(), 1 / timer.get_average(), timer._cnt))

    def _parse_args(self):
        parser = argparse.ArgumentParser(description='PP-HumanSeg inference for video')
        parser.add_argument("--config", dest="cfg", help="The config file.", default=self.config_path, type=str)
        parser.add_argument("--input_shape", dest="input_shape", help="The image shape [h, w] for net inputs.", nargs=2, default=self.input_shape, type=int)
        parser.add_argument('--img_path', dest='img_path', help='Image including human', type=str, default=self.img_path)
        parser.add_argument('--video_path', dest='video_path', help='Video path for inference', type=str, default=self.video_path)
        parser.add_argument('--bg_img_path', dest='bg_img_path', help='Background image path for replacing. If not specified, a white background is used', type=str, default=self.bg_img_path)
        parser.add_argument('--bg_video_path', dest='bg_video_path', help='Background video path for replacing', type=str, default=self.bg_video_path)
        parser.add_argument('--save_dir', dest='save_dir', help='The directory for saving the inference results', type=str, default=self.save_dir)
        parser.add_argument('--use_optic_flow', dest='use_optic_flow', help='Use optical flow for post-processing.', action='store_true')
        parser.add_argument('--soft_predict', dest='soft_predict', default=self.soft_predict, type=eval, choices=[True, False], help='Whether to use predict results with transparency')
        parser.add_argument('--add_argmax', dest='add_argmax', help='Perform argmax operation on the predict result.', action='store_true')
        parser.add_argument('--test_speed', dest='test_speed', help='Whether to test inference speed', action='store_true')
        return parser.parse_args()

    def _get_bg_img(self, bg_img_path, img_shape):
        if bg_img_path is None:
            bg = 255 * np.ones(img_shape)
        elif not osp.exists(bg_img_path):
            raise Exception('The --bg_img_path is not existed: {}'.format(bg_img_path))
        else:
            bg = cv2.imread(bg_img_path)
        return bg

该程序文件名为bg_replace.py,是一个用于背景替换的程序。程序使用了PaddlePaddle深度学习框架,通过加载预训练模型对图像或视频中的人进行分割,并将其替换为指定的背景图像或背景视频。

程序首先通过命令行参数解析获取配置文件路径、输入图像尺寸、图像路径、视频路径、背景图像路径、背景视频路径、保存结果的目录等参数。

然后,程序调用get_sys_env()函数获取系统环境信息,并根据环境信息判断是否使用GPU进行推理。

接下来,程序调用Predictor类的构造函数创建一个预测器对象。

如果指定了图像路径,则程序读取图像和背景图像,并调用预测器的run()方法对图像进行背景替换,并将结果保存到指定的目录中。

如果指定了视频路径,则程序读取视频和背景图像(如果指定了背景视频路径),并逐帧调用预测器的run()方法对每一帧图像进行背景替换,并将结果保存为视频文件。

如果既没有指定图像路径也没有指定视频路径,则程序打开摄像头,实时读取摄像头的图像,并进行背景替换。

最后,如果指定了--test_speed参数,则程序输出模型推理时间和帧率的统计信息。

程序还定义了一个get_bg_img()函数,用于获取背景图像。如果没有指定背景图像路径,则创建一个全白图像作为背景。如果指定了背景图像路径,则读取背景图像。

程序的入口是if __name__ == "__main__":,在这里解析命令行参数,并调用background_replace()函数进行背景替换。

5.2 cbam_module.py


class CBAM(nn.Module):
    def __init__(self, channels):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(channels)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x


class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.shared_MLP = nn.Sequential(
            nn.Linear(channels, channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(channels // reduction_ratio, channels)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_pool = self.avg_pool(x).view(x.size(0), -1)
        max_pool = self.max_pool(x).view(x.size(0), -1)
        avg_out = self.shared_MLP(avg_pool)
        max_out = self.shared_MLP(max_pool)
        out = avg_out + max_out
        return self.sigmoid(out).view(x.size(0), x.size(1), 1, 1)


class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        max_pool, _ = torch.max(x, dim=1, keepdim=True)
        pool = torch.cat([avg_pool, max_pool], dim=1)
        out = self.conv1(pool)
        return self.sigmoid(out)

这个程序文件是一个实现了CBAM(Convolutional Block Attention Module)模块的PyTorch模型。CBAM模块是一种用于图像分类和目标检测任务的注意力机制模块,它可以自适应地学习通道注意力和空间注意力,以提高模型的性能。

该文件定义了三个类:ChannelAttention、SpatialAttention和CBAM。ChannelAttention类实现了通道注意力机制,它通过自适应平均池化和自适应最大池化来提取输入特征图的平均值和最大值,并通过一个共享的多层感知机(MLP)来学习通道注意力权重。SpatialAttention类实现了空间注意力机制,它通过对输入特征图进行平均池化和最大池化,并将它们连接起来,然后通过一个卷积层来学习空间注意力权重。CBAM类将通道注意力和空间注意力结合起来,形成一个完整的CBAM模块。

在forward方法中,CBAM类首先调用channel_attention方法来计算通道注意力权重,然后调用spatial_attention方法来计算空间注意力权重,最后将两者相乘得到最终的注意力图。

5.3 id_photo_cropper.py


class IDPhotoCropper:
    def __init__(self):
        self.face_detector = cv2.dnn.readNetFromCaffe(
            'deploy.prototxt.txt', 'res10_300x300_ssd_iter_140000.caffemodel')

    def crop(self, input_image):
        blob = self._preprocess_input(input_image)
        detections = self._detect_faces(blob)

        if detections is not None:
            face = self._crop_face(input_image, detections)
            return face

        return None

    def _preprocess_input(self, input_image):
        (h, w) = input_image.shape[:2]
        blob = cv2.dnn.blobFromImage(cv2.resize(input_image, (300, 300)), 1.0,
                                     (300, 300), (104.0, 177.0, 123.0))
        return blob

    def _detect_faces(self, blob):
        self.face_detector.setInput(blob)
        detections = self.face_detector.forward()

        i = np.argmax(detections[0, 0, :, 2])
        confidence = detections[0, 0, i, 2]

        if confidence > 0.5:
            return detections[0, 0, i, 3:7]

        return None

    def _crop_face(self, input_image, detections):
        (h, w) = input_image.shape[:2]
        box = detections * np.array([w, h, w, h])
        (startX, startY, endX, endY) = box.astype("int")
        face = input_image[startY:endY, startX:endX]
        return face

这个程序文件名为id_photo_cropper.py,它是一个用于裁剪身份证照片的类。该类包含一个构造函数和一个crop方法。

构造函数中,程序使用cv2.dnn.readNetFromCaffe函数加载了一个基于Caffe框架的人脸检测器模型。这个模型是通过两个文件’deploy.prototxt.txt’和’res10_300x300_ssd_iter_140000.caffemodel’来定义和训练的。

crop方法接受一个输入图像作为参数,并返回裁剪后的人脸图像。在方法中,程序首先获取输入图像的高度和宽度。然后,程序使用cv2.dnn.blobFromImage函数将输入图像调整为300x300大小的blob,并进行一些预处理操作。

接下来,程序将blob输入到人脸检测器中,并获取检测结果。程序假设只有一个人脸被检测到,因此选择具有最高置信度的检测结果作为目标人脸。如果置信度大于0.5,则程序根据检测结果计算出人脸的边界框,并从输入图像中裁剪出人脸图像。

最后,如果没有检测到人脸或置信度较低,则返回None。

这个程序使用了OpenCV和NumPy库来进行图像处理和人脸检测。

5.4 portrait_segmentation_model.py


class PortraitSegmentationModel(nn.Module):
    def __init__(self):
        super(PortraitSegmentationModel, self).__init__()
        self.backbone = mobilenet_v2(pretrained=True).features
        self.semantic_branch = self._create_branch(128, with_cbam=True)
        self.detail_branch = self._create_branch(128, with_cbam=False)
        self.fusion_branch = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def _create_branch(self, out_channels, with_cbam):
        layers = [
            nn.Conv2d(1280, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU()
        ]
        if with_cbam:
            layers.insert(2, CBAM(out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        features = self.backbone(x)
        semantic_output = self.semantic_branch(features)
        detail_output = self.detail_branch(features)
        fusion_output = self.fusion_branch(torch.cat([semantic_output, detail_output], dim=1))
        return fusion_output

这个程序文件是一个肖像分割模型的定义。它使用了PyTorch库,并引入了一些必要的模块和函数。

这个模型的主要结构包括一个MobileNetV2的骨干网络,一个语义分支,一个细节分支和一个融合分支。

在初始化函数中,首先创建了一个MobileNetV2的骨干网络,并加载了预训练的权重。然后通过调用_create_branch函数创建了语义分支和细节分支。语义分支和细节分支都是由一系列卷积层和激活函数组成的。其中,语义分支还包括了一个CBAM模块,用于增强特征的表示能力。最后,创建了一个融合分支,它由两个卷积层和一个Sigmoid激活函数组成。

在前向传播函数中,首先将输入数据通过骨干网络得到特征表示。然后将特征表示分别输入到语义分支和细节分支中,得到相应的输出。最后,将语义分支和细节分支的输出在通道维度上进行拼接,并输入到融合分支中,得到最终的输出。

整个模型的目标是对输入的肖像图像进行分割,将人物和背景分离出来。

5.5 predict.py


class ModelPrediction:
    def __init__(self, cfg, model_path, image_path, save_dir, aug_pred=False, scales=1.0, flip_horizontal=False, flip_vertical=False, is_slide=False, crop_size=None, stride=None):
        self.cfg = cfg
        self.model_path = model_path
        self.image_path = image_path
        self.save_dir = save_dir
        self.aug_pred = aug_pred
        self.scales = scales
        self.flip_horizontal = flip_horizontal
        self.flip_vertical = flip_vertical
        self.is_slide = is_slide
        self.crop_size = crop_size
        self.stride = stride

    def get_image_list(self):
        """Get image list"""
        valid_suffix = [
            '.JPEG', '.jpeg', '.JPG', '.jpg', '.BMP', '.bmp', '.PNG', '.png'
        ]
        image_list = []
        image_dir = None
        if os.path.isfile(self.image_path):
            if os.path.splitext(self.image_path)[-1] in valid_suffix:
                image_list.append(self.image_path)
            else:
                image_dir = os.path.dirname(self.image_path)
                with open(self.image_path, 'r') as f:
                    for line in f:
                        line = line.strip()
                        if len(line.split()) > 1:
                            line = line.split()[0]
                        image_list.append(os.path.join(image_dir, line))
        elif os.path.isdir(self.image_path):
            image_dir = self.image_path
            for root, dirs, files in os.walk(self.image_path):
                for f in files:
                    if '.ipynb_checkpoints' in root:
                        continue
                    if os.path.splitext(f)[-1] in valid_suffix:
                        image_list.append(os.path.join(root, f))
        else:
            raise FileNotFoundError(
                '`--image_path` is not found. it should be an image file or a directory including images'
            )

        if len(image_list) == 0:
            raise RuntimeError('There are not image file in `--image_path`')

        return image_list, image_dir

    def predict(self):
        env_info = get_sys_env()
        place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
            'GPUs used'] else 'cpu'

        paddle.set_device(place)
        if not self.cfg:
            raise RuntimeError('No configuration file specified.')

        cfg = Config(self.cfg)
        val_dataset = cfg.val_dataset
        if not val_dataset:
            raise RuntimeError(
                'The verification dataset is not specified in the configuration file.'
            )

        msg = '\n---------------Config Information---------------\n'
        msg += str(cfg)
        msg += '------------------------------------------------'
        logger.info(msg)

        model = cfg.model
        transforms = val_dataset.transforms
        image_list, image_dir = self.get_image_list()
        logger.info('Number of predict images = {}'.format(len(image_list)))

        config_check(cfg, val_dataset=val_dataset)

        predict(
            model,
            model_path=self.model_path,
            transforms=transforms,
            image_list=image_list,
            image_dir=image_dir,
            save_dir=self.save_dir,
            aug_pred=self.aug_pred,
            scales=self.scales,
            flip_horizontal=self.flip_horizontal,
            flip_vertical=self.flip_vertical,
            is_slide=self.is_slide,
            crop_size=self.crop_size,
            stride=self.stride,
        )

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Model prediction')

    # params of prediction
    parser.add_argument(
        "--config", dest="cfg", help="The config file.", default=None, type=str)
    parser.add_argument(
        '--model_path',
        dest='model_path',
        help='The path of model for prediction',
        type=str,
        default=None)
    parser.add_argument(
        '--image_path',
        dest='image_path',
        help=
        'The path of image, it can be a file or a directory including images',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the predicted results',
        type=str,
        default='./output/result')

    # augment for prediction
    parser.add_argument(
        '--aug_pred',
        dest='aug_pred',
        help='Whether to use mulit-scales and flip augment for prediction',
        action='store_true')
    parser.add_argument(
        '--scales',
        dest='scales',
        nargs='+',
        help='Scales for augment',
        type=float,
        default=1.0)
    parser.add_argument(
        '--flip_horizontal',
        dest='flip_horizontal',
        help='Whether to use flip horizontally augment',
        action='store_true')
    parser.add_argument(
        '--flip_vertical',
        dest='flip_vertical',
        help='Whether to use flip vertically augment',
        action='store_true')

    # sliding window prediction
    parser.add_argument(
        '--is_slide',
        dest='is_slide',
        help='Whether to prediction by sliding window',
        action='store_true')
    parser.add_argument(
        '--crop_size',
        dest='crop_size',
        nargs=2,
        help=
        'The crop size of sliding window, the first is width and the second is height.',
        type=int,
        default=None)
    parser.add_argument(
        '--stride',
        dest='stride',
        nargs=2,
        help=
        'The stride of sliding window, the first is width and the second is height.',
        type=int,
        default=None)

    args = parser.parse_args()
    model_prediction = ModelPrediction(args.cfg, args.model_path, args.image_path, args.save_dir, args.aug_pred, args.scales, args.flip_horizontal, args.flip_vertical, args.is_slide, args.crop_size, args.stride)
    model_prediction.predict()

该程序文件名为predict.py,主要用于模型预测。程序首先导入了必要的库和模块,包括argparse、os、paddle等。然后定义了一个parse_args函数,用于解析命令行参数。接下来定义了一个get_image_list函数,用于获取图片列表。然后定义了一个main函数,主要用于加载配置文件、验证数据集、进行预测等操作。最后,在程序的主入口处,调用parse_args函数解析命令行参数,并调用main函数进行模型预测。

5.6 train.py

class ModelTraining:
    def __init__(self, cfg, iters, batch_size, learning_rate, save_interval, resume_model, save_dir, keep_checkpoint_max, num_workers, do_eval, log_iters, use_vdl):
        self.cfg = cfg
        self.iters = iters
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.save_interval = save_interval
        self.resume_model = resume_model
        self.save_dir = save_dir
        self.keep_checkpoint_max = keep_checkpoint_max
        self.num_workers = num_workers
        self.do_eval = do_eval
        self.log_iters = log_iters
        self.use_vdl = use_vdl

    def train_model(self):
        env_info = get_sys_env()
        info = ['{}: {}'.format(k, v) for k, v in env_info.items()]
        info = '\n'.join(['', format('Environment Information', '-^48s')] + info +
                         ['-' * 48])
        logger.info(info)

        place = 'gpu' if env_info['Paddle compiled with cuda'] and env_info[
            'GPUs used'] else 'cpu'

        paddle.set_device(place)
        if not self.cfg:
            raise RuntimeError('No configuration file specified.')

        cfg = Config(
            self.cfg,
            learning_rate=self.learning_rate,
            iters=self.iters,
            batch_size=self.batch_size)

        train_dataset = cfg.train_dataset
        if train_dataset is None:
            raise RuntimeError(
                'The training dataset is not specified in the configuration file.')
        elif len(train_dataset) == 0:
            raise ValueError(
                'The length of train_dataset is 0. Please check if your dataset is valid'
            )
        val_dataset = cfg.val_dataset if self.do_eval else None
        losses = cfg.loss

        msg = '\n---------------Config Information---------------\n'
        msg += str(cfg)
        msg += '------------------------------------------------'
        logger.info(msg)

        config_check(cfg, train_dataset=train_dataset, val_dataset=val_dataset)

        train(
            cfg.model,
            train_dataset,
            val_dataset=val_dataset,
            optimizer=cfg.optimizer,
            save_dir=self.save_dir,
            iters=cfg.iters,
            batch_size=cfg.batch_size,
            resume_model=self.resume_model,
            save_interval=self.save_interval,
            log_iters=self.log_iters,
            num_workers=self.num_workers,
            use_vdl=self.use_vdl,
            losses=losses,
            keep_checkpoint_max=self.keep_checkpoint_max)

def parse_args():
    parser = argparse.ArgumentParser(description='Model training')
    # params of training
    parser.add_argument(
        "--config", dest="cfg", help="The config file.", default=None, type=str)
    parser.add_argument(
        '--iters',
        dest='iters',
        help='iters for training',
        type=int,
        default=None)
    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        help='Mini batch size of one gpu or cpu',
        type=int,
        default=None)
    parser.add_argument(
        '--learning_rate',
        dest='learning_rate',
        help='Learning rate',
        type=float,
        default=None)
    parser.add_argument(
        '--save_interval',
        dest='save_interval',
        help='How many iters to save a model snapshot once during training.',
        type=int,
        default=1000)
    parser.add_argument(
        '--resume_model',
        dest='resume_model',
        help='The path of resume model',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the model snapshot',
        type=str,
        default='./output')
    parser.add_argument(
        '--keep_checkpoint_max',
        dest='keep_checkpoint_max',
        help='Maximum number of checkpoints to save',
        type=int,
        default=5)
    parser.add_argument(
        '--num_workers',
        dest='num_workers',
        help='Num workers for data loader',
        type=int,
        default=0)
    parser.add_argument(
        '--do_eval',
        dest='do_eval',
        help='Eval while training',
        action='store_true')
    parser.add_argument(
        '--log_iters',
        dest='log_iters',
        help='Display logging information at every log_iters',
        default=10,
        type=int)
    parser.add_argument(
        '--use_vdl',
        dest='use_vdl',
        help='Whether to record the data to VisualDL during training',
        action='store_true')

    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
    model_training = ModelTraining(args.cfg, args.iters, args.batch_size, args.learning_rate, args.save_interval, args.resume_model, args.save_dir, args.keep_checkpoint_max, args.num_workers, args.do_eval, args.log_iters, args.use_vdl)
    model_training.train_model()

该程序文件名为train.py,是一个用于模型训练的脚本。主要功能包括解析命令行参数、获取系统环境信息、配置模型训练参数、加载数据集、进行模型训练等。

具体功能如下:

  1. 解析命令行参数,包括配置文件路径、训练迭代次数、批量大小、学习率等参数。
  2. 获取系统环境信息,包括是否编译了CUDA、使用的GPU数量等。
  3. 配置模型训练参数,包括学习率、迭代次数、批量大小等。
  4. 加载训练数据集和验证数据集。
  5. 进行模型训练,包括模型初始化、优化器配置、保存模型等。
  6. 可选择在训练过程中进行评估。
  7. 可选择使用VisualDL记录训练数据。

总体来说,该程序文件是一个用于配置和执行模型训练的脚本,通过命令行参数指定训练配置和参数,然后进行模型训练并保存模型。

6.系统整体结构

整体功能和构架概括:

该项目是一个基于深度学习的图像视频人像分割背景替换系统。它使用了多个程序文件来实现不同的功能,包括背景替换、注意力机制、人脸裁剪、肖像分割模型的定义、模型预测和模型训练等。

下表整理了每个文件的功能:

文件名功能
bg_replace.py背景替换程序
cbam_module.pyCBAM模块的定义
id_photo_cropper.py身份证照片裁剪类
portrait_segmentation_model.py肖像分割模型的定义
predict.py模型预测脚本
train.py模型训练脚本
ui.py用户界面脚本
val.py模型验证脚本
datasets\humanseg.py人像分割数据集类
deploy\infer.py模型推理脚本
scripts\optic_flow_process.py光流处理脚本
scripts\train.py模型训练脚本
scripts_init_.py脚本初始化

请注意,表格中只列出了给出的文件路径中的文件,其他文件可能也对整个项目起到了一定的作用,但无法在给定的路径中找到。

7.人像抠图模块

人像抠图模块中,在经过人像抠图数据集训练模型的帮助下,用户输入的包含人像的自然图片,能够被端到端地标记出图片中属于人物前景的像素。人像抠图模型在完成训练和保存后,将会进行细分、量化操作并转换成ONNX模型,然后在证件照裁剪模块中得到使用。而在人像抠图模块中,预测遮罩的准确度取决于人像抠图模型的预测效果,因此使用高质量的人像抠图模型十分重要。
人像抠图模型可以有两类实现方法:一种是基于三元图的方法,而另一种是无需三元图的方法。前者需要三元图作为额外输入,如图2.2(b)所示,并通过传统算法或深度学习方法进行预测。然而这种方法具有更高的使用门槛,用户需要交互式地提供三元图或笔画,这对证件照生成系统中的推广使用是不利的。无需三元图的方法一般都通过深度学习方法实现,通过人像抠图数据集对模型进行训练,模型能够仅通过输入图片就完成遮罩的输出,如图所示。
在这里插入图片描述

深度人像抠图模型大多数采用卷积神经网络实现,并且模型的整体可以抽象为“编码器-解码器”结构。编码器采用经典的骨干网络或者全卷积网络实现,通过编码器,特征图能够包含图片中不同层次的语义信息。在语义分割任务中有许多骨干网络可以用来作为编码器,例如ShuffleNetl37l,DeepLabl32l和 MobileNet[3l等。这些网络都有层次深、通道多的特点,DeepLab系列网络,在语义分割任务中获得了很高的准确率,而MobileNet 以它体积小、效率高的特点获得了许多方法的青睐。对于移动端设备或边缘设备来说,MobileNet可以保证模型在具有高效推理速度的同时拥有毫不逊色的预测精准。解码器网络通常采用全卷积网络实现,通过对来自编码器的抽象特征进行解码,来对最终的遮罩进行预测。在语义分割任务中,有许多网络结构适合用来充当解码器,例如U-Netl33中的梯形结构,PSPNet[34中的金字塔结构等。
MODNet提出了一骨干三分支结构,也就是骨干网络作为编码器而三个分支网络作为解码器。相较于Xu等23]提出的粗分割-精炼的两分支结构,三分支结构引入的额外的细节分支,可以对图像过渡区域进行识别并预测出过渡区域的细节,而这与三元图输入的作用是相似的,因此这种结构对于无需三元图的方法十分适用。
模型由MobileNetV2[35]作为骨干网络,以及语义分割分支、细节预测分支和特征融合分支作为分支网络。其中,MobileNetV2作为特征编码器,将会通过卷积网络对图片中5种不同分辨率的特征进行提取,这些特征包含了清晰的表层信息和抽象的语义信息。在分支网络中,语义分割分支通过骨干网络的特征,完成输入图片中人物和背景的粗糙分割,而边缘区域的细节将会被忽略;细节预测分支通过骨干网络和语
义分支的特征,完成人物前景中边缘区域的识别,以反边缘区攻中承场和进R行础合。分离任务;特征融合分支将会把来自语义分支的特征以及细节分文的将位进P础自完成最终精细遮罩的预测任务。如图2.3所示,不同的子网络之间存在着特征信息的
传递,并且每个分支通过不同的标签来完成对应能力的学习。
在这里插入图片描述

在语义分割任务中,数据集只需要输入图片以及对应的遮罩作为标签即可。而对于精确度要求更高的人像抠图任务来说,数据集中还需要添加更加精确的遮罩。如图2.4所示,小型精细数据集通过数据增强后与粗糙数据集一起参与模型的训练,而训练完成之后的模型将会在基准图片上进行性能测试评估,如果模型产生的遮罩满足证件照生成系统的使用要求,那么模型将会被保存,并在经过剪枝、量化后转化为ONNX模型,在运行框架上得到部署。如果遮罩效果太差,那么混合数据集中两种数据集的比例或者模型网络结构会得到调整,从而获得最好的训练效果。

在这里插入图片描述

8.系统抠图模型设计

骨干网络

模型使用MobileNetV2作为骨干网络,MobileNetV2是一种可以用在大多数计算机视觉任务上的结构网络,并且可以根据用途选择不同的输入大小和参数因子,因此这可以减少模型的运算量,从而提升在移动端设备上的计算性能[35]。MobileNetV2的轻量和快速主要是因为引入了深度可分离卷积、线性瓶颈和逆残差模块。

普通卷积

如图所示,卷积操作是卷积核参数矩阵与图像局部区域中的数据矩阵计算内积,并将内积加入到新的特征图的过程。通过赋予卷积核的不同的权重参数,卷积操作能够对它感兴趣的特征进行提取,并反映在特征图中。数据输入通过卷积核的卷积操作,使得输入中的数据以不同的权重被提取,这就成为了特征图。而一个卷积层中通常有许多不同的卷积核,通过训练,这些卷积核就具有了不同的特征提取能力,当不同的卷积核对特征图进行卷积操作后,就产生多通道特征。
在这里插入图片描述

深度可分离卷积

深度可分离卷积在许多卷积神经网络结构中用到,它的主要原理是将一个卷积层分成两个子层,如表3.1所示,第一层是深度卷积,它作为一个轻量滤波器应用到每个通道上,第二层是一维卷积,称为点度卷积,它的作用是通过输入通道的线性组合来构建新的特征。
在深度卷积神经网络中,存在如ReLU变换的非线性激活函数,如果这些信息完整度较高,那么经过非线性的激活函数时会有不可避免的空间坍塌,导致了信息的损失。而在低维空间中,嵌入到低维空间时激活ReLU变换能会导致一些信息被滤除。所以为了解决问题,使用了线性瓶颈对非线性激活变换进行替代。因此卷积最后接入线性瓶颈。
标准的卷积以HwC的特征作为输入,并通过卷积核生成H * W *C,的输出,如果卷积核大小为k,那么总消耗将如公式(3.1)所示:
在这里插入图片描述

而使用了深度可分离卷积的成本如公式(3.2)所示:
在这里插入图片描述

在公式(3.2)中,前一项是深度卷积的成本,后一项是点卷积的成本。通过这一项改进,MobileNetV2在卷积核大小为3时减少了8-9倍的计算量,却获得了相比与Howard 等[36]一点点的精度损失。

逆残差模块

与残差模块不同的是,逆残差模块的两端不包含非线性变换。在残差结构中,特征通过模块先降维,然后再升维,呈现一个沙漏型结构,而逆残差结构的两端由线性瓶颈组成,特征在模块中先升维,再降维,这就是为什么它被称为逆残差结构。该改
在这里插入图片描述
动使得梯度在多个层之间的传播能力得到改进,并且内存损耗变得更小。对于一个线性瓶颈残差模块来说,假如输入通道数是C,卷积核大小为k,输出通道数是C,那么一个模块所需要的乘累加计算数将如公式所示:
在这里插入图片描述

与MobileNetV1[36和 ShuffleNet[3’l相比,MobileNetV2在相同的分辨率大小下,要达到相同的性能所需要的最大通道数和内存消耗均有下降。为了在大多数设备上兼容,模块通过标准的操作实现,可以在大多数框架上运行并保持最佳的性能。而且一些较大的临时数据不会在运行时被加载,这就减少了计算过程中的内存占用。

MobileNetV2的使用

框架提供了MobileNetV2基于ImageNet数据集的预训练权重。然而,该预训练权重更适合图像分类任务。为了让骨干网络能够更好地提取语义分割方面的特征,需要使用大型语义分割数据集对骨干进行预训练。预训练分为两个阶段,首先,预训练将基于骨干网络编码器和U-Net形状的语义解码器构成的语义分割模型,并将骨干网络的权重冻结,以Chen等[27]提供的大型语义分割数据集进行训练。训练集中 33426张图片将作为训练集,而余下的1000张图片将作为验证集。在第一阶段的训练中,模型在0.0001的学习率下训练10轮后收敛。然后,在第二个阶段中将会解冻骨干网络的权重,在相同的数据集下,以0.00001的低学习率训练10轮,选择低学习率是为了避免骨干网络的浅层受到梯度更新的影响,从而造成过拟合。
在训练集数据进入网络之前,需要对数据采取预处理操作,以便符合骨干网络的要求。首先,骨干网络以固定大小的图片数据集进行训练,因此需要同样大小的图片向量作为输入,网络才能正确地提取到图片中的语义特征.所以,训练集中的600800大小的图片将会通过压缩而不是裁剪的方式减小到224224,这是因为在人像抠图任务中背景的语义信息也比较重要,裁剪虽然会保留局部细节的语义但会失去一些全局信息。然后,骨干网络要求输入图片数据的数值范围从0255标准化为-11之间,这是为了与ImageNet 数据集的预处理相匹配。
随后,MobileNetV2将会对数据输入进行编码,也就是计算MobileNetV2前向传播过程中的卷积层特征。网络将会选择五个由浅到深的卷积层特征,作为后续解码分支的输入。从深层到浅层的特征维度分别是7T320、1414576,2828192、5656144以及11211296,越深的卷积层会具有更高级的语义特征,但缺少低层的全局信息,因此这些特征需要进行不同的组合和解码,才能让分支得到它们需要的信息。在训练过程中,骨干网络的权重将会被冻结,也就是损失函数对权重的反向传播将不会更新骨干网络的权重。
在这里插入图片描述

9.训练结果分析

图片在加载之前,其宽和高都被压缩成224,因为骨干网络 MobileNetV2预训练权重适用于224*224尺寸的输入,使用其他尺寸的输入骨干网络将不能正确地进行特征提取。除了压缩体积外,输入图片被标准化为-1到1,标签图片被标准化为0到1。随后数据集将会被随机打乱,并以32的批尺寸送入训练。
模型训练在RTX3090下展开,训练以Adam 作为优化器,训练过程采用恒定的学习率0.01,并经过30轮训练。如图所示,随着训练轮次的增加,训练集的语义分割损失、细节预测损失和遮罩损失均在持续下降,并且下降的速度不断减慢,这表明模型权重在训练过程逐渐收敛。

在这里插入图片描述

10.系统整合

下图完整源码&环境部署视频教程&自定义UI界面

在这里插入图片描述

参考博客《基于深度学习的图像视频人像分割背景替换系统》

11.参考文献


[1]苏常保,龚世才.基于深度学习的人物肖像全自动抠图算法[J].图学学报.2022,43(2).DOI:10.11996/JG.j.2095-302X.2022020247 .

[2]计梦予,袭肖明,于治楼.基于深度学习的语义分割方法综述[J].信息技术与信息化.2017,(10).DOI:10.3969/j.issn.1672-9528.2017.10.037 .

[3]孙巍.视觉感知特性指导下的自然图像抠图算法研究[J].北京交通大学.2015.

[4]管宇.图像和视频的便捷抠图技术研究[J].浙江大学理学院.2008.

[5]Zhao, He,Li, Huiqi,Cheng, Li.Improving retinal vessel segmentation with joint local loss by matting[J].Pattern Recognition: The Journal of the Pattern Recognition Society.2020.98DOI:10.1016/j.patcog.2019.107068 .

[6]Fan, Zhun,Lu, Jiewei,Wei, Caimin,等.A Hierarchical Image Matting Model for Blood Vessel Segmentation in Fundus Images[J].IEEE Transactions on Image Processing.2019,28(5).2367-2377.DOI:10.1109/TIP.2018.2885495 .

[7]Johnson, Jubin,Varnousfaderani, Ehsan Shahrian,Cholakkal, Hisham,等.Sparse Coding for Alpha Matting[J].IEEE Transactions on Image Processing.2016,25(7).3032-3043.DOI:10.1109/TIP.2016.2555705 .

[8]Chen,Qifeng,Li,等.KNN Matting[J].IEEE Transactions on Pattern Analysis & Machine Intelligence.2013,35(9).

[9]Fan, Jialue.Scribble Tracker: A Matting-Based Approach for Robust Tracking[J].IEEE Transactions on Pattern Analysis & Machine Intelligence.2012,34(8).

[10]Eduardo S. L. Gastal,Manuel M. Oliveira.Shared Sampling for Real‐Time Alpha Matting[J].Computer Graphics Forum.2010,29(2).575-584.DOI:10.1111/j.1467-8659.2009.01627.x .

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值