Introduction
I hear a lot of you are looking for a beginner-friendly project to practice on, so today I'm bringing you a small hands-on background-replacement project.
If you're just getting started with deep learning, it's worth a look.
I lightly modified the BackgroundMattingV2 project: on top of matting a single image, it can now paste the cut-out subject onto a custom background image, so the person in the photo can appear against any background you like. Pretty fun, right?
Project description
Project structure
Let's first take a look at the project's structure, as shown below:
The model folder holds the model file, which can be downloaded from:
https://pan.baidu.com/s/1dNDJIOjxIUV3Q30vpp0A0w?passwd=vtx6
Download the model and put it in the model folder.
Dependencies are listed in requirements.txt. One note: PyTorch should be installed with the command given on the official website, so that the build matches your GPU driver; a quick sanity check is shown after the list.
The dependency file is as follows:
kornia==0.4.1
tensorboard==2.3.0
torch==1.7.0
torchvision==0.8.1
tqdm==4.51.0
opencv-python==4.4.0.44
onnxruntime==1.6.0
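If you are not sure whether the PyTorch build you installed actually sees your GPU, here is a quick check (a minimal sketch, not part of the original project):

import torch

print(torch.__version__)          # e.g. 1.7.0
print(torch.cuda.is_available())  # True means the CUDA build matches your driver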
Data preparation
We need a photo, the clean background shot of that photo, and the new background image you want to swap in. I used some of the reference images provided by BackgroundMattingV2; the source image and its background are shown below:
The new background image (one I picked at random) is shown below:
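Before running the script, it is worth confirming that the three images load correctly. Here is a minimal sketch; the file names are placeholders for your own images:

from PIL import Image

src = Image.open('src.png').convert('RGB')          # the photo
bgr = Image.open('bgr.png').convert('RGB')          # the clean background shot of that photo
new_bgr = Image.open('new_bgr.png').convert('RGB')  # the background to composite onto

print(src.size, bgr.size, new_bgr.size)  # ideally src and bgr share the same resolution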
Background replacement code
Without further ado, here is the core code.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2021/11/14 21:24
# @Author : 剑客阿良_ALiang
# @Site :
# @File : inferance_hy.py
import argparse
import torch
import os
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from threading import Thread
from tqdm import tqdm
from torch.utils.data import Dataset
from PIL import Image
from typing import Callable, Optional, List, Tuple
import glob
from torch import nn
from torchvision.models.resnet import ResNet, Bottleneck
from torch import Tensor
import torchvision
import numpy as np
import cv2
import uuid
# --------------- hy ---------------
class HomographicAlignment:
    """
    Apply homographic alignment on background to match with the source image.
    """

    def __init__(self):
        self.detector = cv2.ORB_create()
        self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE)

    def __call__(self, src, bgr):
        src = np.asarray(src)
        bgr = np.asarray(bgr)

        keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None)
        keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None)

        matches = self.matcher.match(descriptors_bgr, descriptors_src, None)
        matches.sort(key=lambda x: x.distance, reverse=False)
        num_good_matches = int(len(matches) * 0.15)
        matches = matches[:num_good_matches]

        points_src = np.zeros((len(matches), 2), dtype=np.float32)
        points_bgr = np.zeros((len(matches), 2), dtype=np.float32)
        for i, match in enumerate(matches):
            points_src[i, :] = keypoints_src[match.trainIdx].pt
            points_bgr[i, :] = keypoints_bgr[match.queryIdx].pt

        H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC)

        h, w = src.shape[:2]
        bgr = cv2.warpPerspective(bgr, H, (w, h))
        msk = cv2.warpPerspective(np.ones((h, w)), H, (w, h))

        # For areas that are outside of the background,
        # we just copy pixels from the source.
        bgr[msk != 1] = src[msk != 1]

        src = Image.fromarray(src)
        bgr = Image.fromarray(bgr)

        return src, bgr
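
# Usage sketch (illustration only, not part of the original script; the file names are placeholders):
#   src, bgr = Image.open('src.png'), Image.open('bgr.png')
#   src, bgr = HomographicAlignment()(src, bgr)   # bgr is warped to line up with src
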
class Refiner(nn.Module):
    # For TorchScript export optimization.
    __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method']

    def __init__(self,
                 mode: str,
                 sample_pixels: int,
                 threshold: float,
                 kernel_size: int = 3,
                 prevent_oversampling: bool = True,
                 patch_crop_method: str = 'unfold',
                 patch_replace_method: str = 'scatter_nd'):
        super().__init__()
        assert mode in ['full', 'sampling', 'thresholding']
        assert kernel_size in [1, 3]
        assert patch_crop_method in ['unfold', 'roi_align', 'gather']
        assert patch_replace_method in ['scatter_nd', 'scatter_element']

        self.mode = mode
        self.sample_pixels = sample_pixels
        self.threshold = threshold
        self.kernel_size = kernel_size
        self.prevent_oversampling = prevent_oversampling
        self.patch_crop_method = patch_crop_method
        self.patch_replace_method = patch_replace_method

        channels = [32, 24, 16, 12, 4]
        self.conv1 = nn.Conv2d(channels[0] + 6 + 4, channels[1], kernel_size, bias=False)
        self.bn1 = nn.BatchNorm2d(channels[1])
        self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False)
        self.bn2 = nn.BatchNorm2d(channels[2])
        self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False)
        self.bn3 = nn.BatchNorm2d(channels[3])
        self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True)
        self.relu = nn.ReLU(True)

    def forward(self,
                src: torch.Tensor,
                bgr: torch.Tensor,
                pha: torch.Tensor,
                fgr: torch.Tensor,
                err: torch.Tensor,
                hid: torch.Tensor):
        H_full, W_full = src.shape[2:]
        H_half, W_half = H_full // 2, W_full // 2
        H_quat, W_quat = H_full // 4, W_full // 4
        src_bgr = torch.cat([src, bgr], dim=1)
        if self.mode != 'full':
            # Select the regions whose coarse prediction error is high, at 1/4 resolution.
            err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False)
            ref = self.select_refinement_regions(err)
            idx = torch.nonzero(ref.squeeze(1))
            idx = idx[:, 0], idx[:, 1], idx[:, 2]

            if idx[0].size(0) > 0:
                # Crop patches around the selected regions at 1/2 resolution and refine them.
                x = torch.cat([hid, pha, fgr], dim=1)
                x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
                x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0)

                y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
                y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0)

                x = self.conv1(torch.cat([x, y], dim=1))
                x = self.bn1(x)
                x = self.relu(x)
                x = self.conv2(x)
                x = self.bn2(x)
                x = self.relu(x)

                # Upsample the patches and refine again against full-resolution source/background crops.
                x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest')
                y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0)

                x = self.conv3(torch.cat([x, y], dim=1))
                x = self.bn3(x)
                x = self.relu(x)
                x = self.conv4(x)

                # Paste the refined patches back into the upsampled coarse prediction.
                out = torch.cat([pha, fgr], dim=1)
                out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False)
                out = self.replace_patch(out, x, idx)
                pha = out[:, :1]
                fgr = out[:, 1:]
            else:
                pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False)
                fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False)
        else:
            # 'full' mode refines the whole image instead of selected patches.
            x = torch.cat([hid, pha, fgr], dim=1)
            x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
            y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
            if self.kernel_size == 3:
                x = F.pad(x, (3, 3, 3, 3))
                y = F.pad(y, (3, 3, 3, 3))

            x = self.conv1(torch.cat([x, y], dim=1))
            x = self.bn1(x)
            x = self.relu(x)
            x = self.conv2(x)
            x = self.bn2(x)
            x = self.relu(x)

            if self.kernel_size == 3:
                x = F.interpolate(x, (H_full + 4, W_full + 4))
                y = F.pad(src_bgr, (2, 2, 2, 2))
            else: