本文作者为 360 奇舞团前端开发工程师
随着AI的火热发展,涌现了一些AI模特换装的前端工具(比如weshop网站),他们是怎么实现的呢?使用了什么技术呢?下文我们就来探索一下其实现原理。
总体的实现流程如下:我们将下图中的这个模特的图片,使用Segment Anything Model在后端分割图层,然后将分割后的图层mask信息返回给前端处理。在前端中选择需要保留的图层信息(如下图中的模特的衣服图层),然后将选中的图层信息交给后端中的Stable Diffusion处理。后端使用原始图片结合选中的图层蒙版图片结合图生图的功能,可以实现weshop等网站的模特换衣等功能。
本文先简单介绍一下使用SAM智能图层分割,然后主要介绍一下在前端中怎么对分割后的图层进行选择的处理流程。
使用SAM识别图层
首先我们需要对图层进行分割,在SAM出来之前,我们需要使用PS将模特的衣服选取出来,然后倒出衣服的模板,然后再使用其他工具进行替换。但是现在有了SAM后,我们可以对图片中的事物进去只能区分,获取各种物品的图层。
Segment Anything Model(SAM)是一种尖端的图像分割模型,可以进行快速分割,为图像分析任务提供无与伦比的多功能性。SAM 的先进设计使其能够在无需先验知识的情况下适应新的图像分布和任务,这一功能称为零样本传输。SAM 使任何人都可以在不依赖标记数据的情况下为其数据创建分段掩码。
要深入了解 Segment Anything 模型和 SA-1B 数据集,请访问Segment Anything 网站(https://segment-anything.com/)并查看研究论文Segment Anything(https://arxiv.org/abs/2304.02643)。
我们使用SAM进行图像分割,将一个图片中的物体分割成不同的部分。
def mask2rle(img):
'''
img: numpy array, 1 - mask, 0 - background
Returns run length as string formated
'''
pixels = img.T.flatten()
pixels = np.concatenate([[0], pixels, [0]])
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
runs[1::2] -= runs[::2]
return ' '.join(str(x) for x in runs)
def trans_anns(anns):
if len(anns) == 0:
return
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=False)
list = []
index = 0
# 对每个注释进行处理
for ann in sorted_anns:
bool_array = ann['segmentation']
# 将boolean类型的数组转换为int类型
int_array = bool_array.astype(int)
# 转化为RLE格式
rle = mask2rle(int_array)
list.append({"index": index, "mask": rle})
index += 1
return list
image = cv2.imread('<your image path>')
import sys
sys.path.append('<your segment-anything link path>')
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# sam 模型路径
sam_checkpoint = '<your sam model path>'
# 根据下载的模型,设置对应的类型
model_type = "vit_h"
# device = "cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)
# 处理sam返回的图层信息
mask_list = trans_anns(masks)
mask_obj = {
"height": image.shape[0],
"width": image.shape[1],
"mask_list": mask_list
}
import json
print(json.dumps(mask_obj))
运行以上python代码之前,需要配置sam的python环境,具体的配置描述请查看sam的官方描述。
我们通过以上代码,将我们