Document image inpainting


import cv2
import torch
import torch.nn.functional as F
import numpy as np
import time


# Prefer the first GPU; fall back to CPU so the script also runs without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def torch_inpaint(pil_img, mask, *, box_ks=15, verbose=True):
    """Erase handwriting from a document image and fill it from nearby pixels.

    Handwriting pixels (grown by two 3x3-cross dilations, i.e. up to 2 px in
    Manhattan distance) are replaced by the mean of the "valid" pixels --
    those that are neither erased nor printed text -- inside a
    ``box_ks`` x ``box_ks`` window (a normalized box filter). Erased pixels
    with no valid pixel in their window get the mean of all background
    pixels instead.

    Args:
        pil_img: image tensor of shape (1, C, H, W). Despite the name this is
            a torch tensor, not a PIL image. It is NOT modified.
        mask: one-hot class tensor of shape (1, 3, H, W) on the same device:
            channel 0 = handwriting, 1 = printed text, 2 = background. Any
            non-zero value counts as set.
        box_ks: positive odd side length of the averaging window.
        verbose: if true, print the elapsed wall-clock time in seconds.

    Returns:
        np.ndarray of shape (H, W, C), dtype uint8 (values are truncated).

    Raises:
        ValueError: if ``pil_img`` is not a single-image 4-D batch, or
            ``box_ks`` is not a positive odd integer.
    """
    if pil_img.dim() != 4 or pil_img.shape[0] != 1:
        raise ValueError(
            f"expected pil_img of shape (1, C, H, W), got {tuple(pil_img.shape)}"
        )
    if box_ks < 1 or box_ks % 2 == 0:
        # An even window with padding (k-1)//2 would shrink the output.
        raise ValueError(f"box_ks must be a positive odd integer, got {box_ks}")

    start = time.perf_counter()
    dev = pil_img.device
    channels = pil_img.shape[1]
    # Work on a float copy so the caller's tensor is left untouched.
    img = pil_img.clone().float()

    hand = mask[:, 0:1].bool()
    printed = mask[:, 1:2].bool()
    background = mask[:, 2:3].bool()

    # Handwriting that is not also printed text, dilated twice with a cross.
    cross = torch.tensor(
        [[[[0, 1, 0], [1, 1, 1], [0, 1, 0]]]], dtype=torch.float32, device=dev
    )
    grown = (hand & ~printed).float()
    for _ in range(2):
        grown = F.conv2d(grown, cross, padding=1)
    # NOTE(review): the dilated region can overlap printed text, which is
    # then overwritten as well (same as the original) -- confirm intended.
    erase = (grown > 0).repeat(1, channels, 1, 1)
    printed = printed.repeat(1, channels, 1, 1)
    background = background.repeat(1, channels, 1, 1)

    # Global background mean (a single scalar over all channels), taken from
    # the unmodified pixels. Falls back to 0.0 when there is no background,
    # instead of producing NaN (whose uint8 cast is undefined).
    # TODO: use the background colour around each erased region instead of
    # the mean over the whole image.
    fallback = img[background].mean().item() if background.any() else 0.0

    # Normalized box filter: per-window sum of valid pixels / their count.
    valid = ~(erase | printed)
    pad = (box_ks - 1) // 2
    box = torch.ones((channels, 1, box_ks, box_ks), dtype=torch.float32, device=dev)
    sums = F.conv2d(img.masked_fill(~valid, 0), box, padding=pad, groups=channels)
    counts = F.conv2d(valid.float(), box, padding=pad, groups=channels)

    img[erase] = (sums / counts.clamp(min=1))[erase]
    # Pick holes by "no valid neighbour", not by a zero result: a genuinely
    # black neighbourhood must keep its fill value.
    img[erase & (counts == 0)] = fallback

    # Index the batch dim instead of squeeze() so H == 1 or W == 1 still work.
    result = img[0].cpu().numpy().transpose(1, 2, 0).astype(np.uint8)
    if verbose:
        print(time.perf_counter() - start)
    return result


if __name__ == '__main__':
    # Demo: erase the handwriting on one sample page and save the result.
    # NOTE(review): the author noted "also apply this to the final result" --
    # presumably the output of a larger pipeline; confirm.

    img = cv2.imread("./331800.jpg")
    mask = cv2.imread("./331800.png", cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None (instead of raising) for missing/unreadable files.
    if img is None or mask is None:
        raise SystemExit("could not read ./331800.jpg and/or ./331800.png")
    if img.shape[:2] != mask.shape:
        raise SystemExit(
            f"image {img.shape[:2]} and mask {mask.shape} sizes differ"
        )

    # Label map: 0 = background, 1 = printed text, 2 = handwriting.
    # torch_inpaint expects one-hot channels ordered handwriting, printed, background.
    mask_one_hot = np.stack([mask == 2, mask == 1, mask == 0], axis=-1).astype(np.uint8)

    # HWC numpy arrays -> (1, C, H, W) tensors on the target device.
    im_tensor = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float().to(device)
    mask_tensor = torch.from_numpy(mask_one_hot).permute(2, 0, 1).unsqueeze(0).to(device)

    output_im = torch_inpaint(im_tensor, mask_tensor)
    if not cv2.imwrite("output_im.png", output_im):
        raise SystemExit("could not write output_im.png")


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值