2018年阿里的论文《Semantatic Human Matting》给出了人像抠图的一个新方法,这是Github上对这个论文的复现
一、网络主干、环境、数据集
1.1、网络主干
通过下面Semantic Human Matting网络图开始讲解SHM的网络设计:
SHM的网络过程:
T-Net
:本质是一个Encoder-Decoder 结构,作用是预测生成trimap图。输入是三通道原图,输出是三通道的trimap_pre图;- 三通道的trimap_pre图,经过softmax操作得到trimap_softmax图,再经过split 操作,得到都是单通道的背景图 B s B_s Bs,前景图 F s F_s Fs,不确定边界图 U s U_s Us;(这一步是在代码中的操作)
M-Net
:本质也是一个Encoder-Decoder 结构,作用是预测生成alpha 图。输入是三通道原图 + 三通道trimap_softmax图,输出是单通道的 α r α_r αr图;Fusion Module
:这部分的作用是得到精准的 α p α_p αp图。计算公式为: α p = F s + U s ∗ α r α_p = F_s + U_s*α_r αp=Fs+Us∗αr(前景图 + 不确定边界图 * 边界概率)
1.2、工程环境
复现代码链接在这里:Semantic Human Matting。环境如下:
Ubuntu 20.04.5 + python3.8.10
cuda 12.1 + cudnn 8.9
Pillow==10.1.0
torch==2.1.0+cu121
opencv-python==4.8.1.78
Windows我没有配置过,有兴趣的可以去试一试。
1.3、数据集
Semantic Human Matting工程主页作者给出了他找到的数据集,在这里对作者及爱分割公司表示感谢。数据集的百度网盘地址,密码:dzsn。
解压后可以看到其下主要包含两个文件夹:
- clip_img文件夹:其下都是三通道的原图;
- matting文件夹:其下都是与原图对应的mask图,仍需处理一下;
- 手动删除matting/1803201916/._matting_00000000这个错误文件;
- 手动修改clip_img/1803201916/clip_00000000/1803201916-00000117.png这个文件后缀,改成jpg后缀;
注意:整个数据集包含三万多张图片,预处理全部文件的话很耗时,所以在调试阶段建议选用其中某一个文件夹就行了。
在工程data
目录下新建matting
、clip_img
两个文件夹,再在原始数据集中挑选一组文件名一样的matting
、clip_img
文件放入其中;
二、数据处理
2.1、matting图生成mask图
先在data
文件夹下新建zcm_matting_get_mask.py
文件,代码如下:
import os
import cv2
matting_path = "matting/"
mask_path = "mask/"
# test
# for mask_name in os.listdir(matting_path):
# in_image = cv2.imread(matting_path + mask_name, cv2.IMREAD_UNCHANGED)
# alpha = in_image[:,:,3]
# cv2.imwrite(mask_path + mask_name, alpha)
for name_0 in os.listdir(matting_path):
if not os.path.exists(mask_path + "/" + name_0):
os.makedirs(mask_path + "/" + name_0)
for name_1 in os.listdir(matting_path + "/" + name_0):
if not os.path.exists(mask_path + name_0 + "/" + name_1):
os.mkdir(mask_path + name_0 + "/" + name_1)
for name_2 in os.listdir(matting_path + "/" + name_0 + "/" + name_1):
pic_input_path = matting_path + "/" + name_0 + "/" + name_1 + "/" + name_2
pic_output_path = mask_path + "/" + name_0 + "/" + name_1 + "/" + name_2
print("pic_input_path=", pic_input_path)
in_image = cv2.imread(pic_input_path, cv2.IMREAD_UNCHANGED)
alpha = in_image[:, :, 3]
cv2.imwrite(pic_output_path, alpha)
执行这个py文件,完成后可以在data
目录下看到生成了一个新的mask文件夹,mask文件夹下存储着黑白底的mask遮罩图。
2.2、生成训练目录:
在data
文件夹下新建zcm_get_train_txt.py
文件,代码如下:
import os
pic_path = "matting/"
with open("train.txt", "w", encoding="UTF-8") as ff:
for name_0 in os.listdir(pic_path):
for name_1 in os.listdir(pic_path + "/" + name_0):
for name_2 in os.listdir(pic_path + "/" + name_0 + "/" + name_1):
pic_input_path = name_0 + "/" + name_1 + "/" + name_2
ff.write(pic_input_path + "\n")
ff.close()
print("well done____________!")
执行这个py文件,完成后可以在data
目录下看到生成了一个新的train.txt
文件,打开里面存储训练数据的所有路径。
2.3、mask图生成trimap图
注释掉gen_trimap.py
第36/42/48行的断言语句
# assert(cnt1 == cnt2 + cnt3)
在gen_trimap.py
引入os库
import os
在gen_trimap.py
第64行后,添加如下代码;
trimap_name_1 = trimap_name.split("/")[:-1]
trimap_path = "/".join(trimap_name_1)
if not os.path.exists(trimap_path):
os.makedirs(trimap_path)
执行sh gen_trimap.sh
脚本,完成后在新生成的trimap
文件夹下保存着每张原图的三色图;(trimap三色图原意就是为了区分前景、背景、不确定区域)
2.4、生成alpha图
说明:这里给出两种生成alpha图的方法:
- 用工程自带的
knn_matting.sh
脚本生成alpha图; - 直接拷贝
mask
文件夹,将mask图作为alpha图注入训练;
第一种方法我在简单测试中使用过,该方法非常的耗时间,而且用该方法处理数据集得到alpha图将其注入训练,对最后预测的准确性的影响并不大;有兴趣的朋友可以对knn_matting继续改进,将时间效率提高;
我也阐述使用第二种方法的依据:因为爱分割公司提供的数据集,matting文件是他们人工扣出来的、是精确的,本质上等同于精分的alpha图。而knn_matting.sh
脚本存在的意义,是对于正常情况下,我们如使用像Faster-RCNN,DeepLab这样的分割算法得到的mask图是不精确的,才需要使用knn_matting算法处理不确定区域,进而得到精准的alpha图。
所以选用第二种方法,在data
文件夹下新建alpha
文件夹,执行复制语句,将mask
文件夹下所有文件复制到alpha
文件夹;
cp -r mask/* alpha/
至此,数据集处理完成。
三、训练
在Semantic Human Matting主目录下,新建train_code.txt
文件,写入如下指令:
T-net
python train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train_all.txt' --lrdecayType='keep' --nEpochs=200 --save_epoch=1 --load='human_matting' --patch_size=256 --lr=5e-5 --gpus='0,1' --nThreads=64 --train_batch=128 --train_phase='pre_train_t_net'
M-net
python train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train_all.txt' --lrdecayType='step' --nEpochs=400 --save_epoch=1 --load='human_matting' --patch_size=256 --lr=5e-5 --gpus='0,1' --nThreads=64 --train_batch=128 --train_phase='end_to_end'
第一段是T-Net
训练代码,第二段是M-Net
训练代码
3.2、修改train.py 文件:
在train.py
文件第29行后添加一条语句,用来指示GPU的使用情况
parser.add_argument('--gpus', default='0,1,2,3', help='gpus number')
3.3、修改dataset.py 文件:
用如下语句替换dataset.py
文件第17/18/19行
image_name = os.path.join(data_dir, 'clip_img', file_name['image'].replace("matting", "clip").replace("png", "jpg"))
trimap_name = os.path.join(data_dir, 'trimap', file_name['trimap'].replace("clip", "matting"))
alpha_name = os.path.join(data_dir, 'alpha', file_name['alpha'].replace("clip", "matting"))
用如下语句替换dataset.py
文件第101/102/103行:
trimap[trimap == 0] = 0
trimap[trimap >= 250] = 2
trimap[np.where(~((trimap == 0) | (trimap == 2)))] = 1
这里是整个代码中错误最隐蔽的一个,当初也是花了我很长时间才搞定。解释一下为什么这样做:我们知道trimap图是三色图,但是它的“三色”并不像上图中0/128/255只有这三色,它是在[0, 255]这个区间范围内。所以新改的代码,将这“三色”用区间区分,作为三种不同的label传入训练。
3.4、开启T-Net 训练:
运行train_code.txt
第一行代码,开启T-Net训练,如果你报内存不足的错误,就适当调小patch_size,nThreads,train_batch的数值;
python train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train_all.txt' --lrdecayType='keep' --nEpochs=200 --save_epoch=1 --load='human_matting' --patch_size=256 --lr=5e-5 --gpus='0,1' --nThreads=64 --train_batch=128 --train_phase='pre_train_t_net'
下图是T-Net 训练过程的loss变化,你也可以为得到更好的结果而增大nEpochs训练轮数;
3.5、开启M-Net 训练:
运行train_code.txt
第二行代码,开启M-Net 微调训练
python train.py --dataDir='./data' --saveDir='./ckpt' --trainData='human_matting_data' --trainList='./data/train_all.txt' --lrdecayType='step' --nEpochs=400 --save_epoch=1 --load='human_matting' --patch_size=256 --lr=5e-5 --gpus='0,1' --nThreads=64 --train_batch=128 --train_phase='end_to_end'
下图是我M-Net 训练过程的loss变化,你也可以为得到更好的结果而增大nEpochs训练轮数;
四、测试
新建test_camera_used.py
文件,写入如下代码,代码与test_camera.py
文件很相似,只是改了一部分需求,让过程更简洁;
在主目录下新建test_pic
文件夹,将测试所用的pic图片存入其中后,运行test_camera_used.py
文件,就能在result
文件夹下得到预测的结果图。
'''
test camera
Author: Zhengwei Li
Date : 2018/12/28
'''
import time
import cv2
import torch
import argparse
import numpy as np
import os
import torch.nn.functional as F
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1, 2, 3'
parser = argparse.ArgumentParser(description='human matting')
parser.add_argument('--model', default='./ckpt/human_matting/model/model_obj.pth', help='preTrained model')
parser.add_argument('--size', type=int, default=320, help='input size')
parser.add_argument('--without_gpu', action='store_true', default=False, help='no use gpu')
args = parser.parse_args()
torch.set_grad_enabled(False)
#################################
#----------------
if args.without_gpu:
print("use CPU !")
device = torch.device('cpu')
else:
if torch.cuda.is_available():
n_gpu = torch.cuda.device_count()
print("----------------------------------------------------------")
print("| use GPU ! || Available GPU number is {} ! |".format(n_gpu))
print("----------------------------------------------------------")
device = torch.device('cuda: 0, 1, 2, 3')
#################################
#---------------
def load_model(args):
print('Loading model from {}...'.format(args.model))
if args.without_gpu:
myModel = torch.load(args.model, map_location=lambda storage, loc: storage)
else:
myModel = torch.load(args.model)
myModel.eval()
myModel.to(device)
# myModel.cuda()
return myModel
def seg_process(args, image, net):
# opencv
origin_h, origin_w, c = image.shape
image_resize = cv2.resize(image, (args.size,args.size), interpolation=cv2.INTER_CUBIC)
image_resize = (image_resize - (104., 112., 121.,)) / 255.0
tensor_4D = torch.FloatTensor(1, 3, args.size, args.size)
tensor_4D[0,:,:,:] = torch.FloatTensor(image_resize.transpose(2,0,1))
inputs = tensor_4D.to(device)
trimap, alpha = net(inputs)
trimap_np = trimap[0, 0, :, :].cpu().data.numpy()
trimap_np = cv2.resize(trimap_np, (origin_w, origin_h), interpolation=cv2.INTER_CUBIC)
mask_result = np.multiply(trimap_np[..., np.newaxis], image)
trimap_1 = mask_result.copy()
mask_result[trimap_1 < 10] = 255
mask_result[trimap_1 >= 10] = 0
cv2.imwrite("mask_result.png", mask_result)
if args.without_gpu:
alpha_np = alpha[0,0,:,:].data.numpy()
else:
alpha_np = alpha[0,0,:,:].cpu().data.numpy()
alpha_np = cv2.resize(alpha_np, (origin_w, origin_h), interpolation=cv2.INTER_CUBIC)
fg = np.multiply(alpha_np[..., np.newaxis], image)
# cv2.imwrite("fg.png", fg)
# bg = image
# bg_gray = np.multiply(1 - alpha_np[..., np.newaxis], image)
# bg_gray = cv2.cvtColor(bg_gray, cv2.COLOR_BGR2GRAY)
# # print("bg_gray=", bg_gray)
# bg[:,:,0] = bg_gray
# bg[:,:,1] = bg_gray
# bg[:,:,2] = bg_gray
#
# # fg[fg<=0] = 0
# # fg[fg>255] = 255
# # fg = fg.astype(np.uint8)
# # out = cv2.addWeighted(fg, 0.7, bg, 0.3, 0)
#
# # out = fg + bg
# # out[out<0] = 0
# # out[out>255] = 255
# # out = out.astype(np.uint8)
#
# out = fg.copy()
# out[out<10] = 0
# out[out>=10] = 255
# out = out.astype(np.uint8)
return fg, mask_result
def camera_seg(args, net):
# videoCapture = cv2.VideoCapture(0)
#
# while(1):
# # get a frame
# ret, frame = videoCapture.read()
# frame = cv2.flip(frame,1)
# frame_seg = seg_process(args, frame, net)
#
#
# # show a frame
# cv2.imshow("capture", frame_seg)
#
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
# videoCapture.release()
test_pic_path = "test_pic/"
output_path = "result/"
if not os.path.exists(output_path):
os.mkdir(output_path)
time_0 = time.time()
for name_ in os.listdir(test_pic_path):
frame = cv2.imread(test_pic_path + name_)
fg, mask_result = seg_process(args, frame, net)
print("SUCCESS_____!", test_pic_path + name_)
cv2.imwrite(output_path + name_.split(".")[0] + "_fg.jpg", fg)
cv2.imwrite(output_path + name_, mask_result)
print("time_all = ", time.time() - time_0)
def main(args):
time_1 = time.time()
myModel = load_model(args)
print("lodding_model_time = ", time.time() - time_1)
camera_seg(args, myModel)
if __name__ == "__main__":
main(args)
五、最后的说明
- 爱分割公司提供的数据集中,某一个目录中有一个没用的隐藏文件,如果不删除的话,数据准备过程、训练过程会报错——文件地址是:matting/1803201916/._matting_00000000
- 我训练了一个较好的model,用上了爱分割公司全部数据集 + 自建的一些数据集,展示测试的结果,左边是预测生成图,右边是原图;
- model的下载地址在本文最上面,绑定资源处;
- 有问题欢迎留言垂询;