Text Image Inpainting via Global Structure-Guided Diffusion Models(GSDM) 训练复现

Text Image Inpainting via Global Structure-Guided Diffusion Models(GSDM)

这份文件是一篇关于文本图像修复(Text Image Inpainting)的研究论文,标题为“Text Image Inpainting via Global Structure-Guided Diffusion Models”,作者包括Shipeng Zhu、Pengfei Fang、Chenjie Zhu、Zuoyan Zhao、Qiang Xu、Hui Xue,来自中国东南大学计算机科学与工程学院以及教育部新一代人工智能技术及其交叉应用重点实验室。

研究背景

现实世界中的文本图像可能会因环境或人为因素而损坏,如腐蚀问题,这影响了文本的完整性,包括纹理和结构。这些问题给文本的理解和下游应用(如场景文本识别和签名识别)带来了挑战。
任务表示

研究目标

本文旨在解决文本图像修复问题,建立基准数据集,并开发了一种新的神经网络框架——全局结构引导的扩散模型(Global Structure-guided Diffusion Model, GSDM)。
网络结构图

方法论

数据集构建:创建了两个特定的文本修复数据集,分别包含场景文本图像和手写文本图像,每个数据集包含原始图像、损坏图像以及其他辅助信息。
GSDM模型:提出了一种新的神经网络模型,利用文本的全局结构作为先验知识,通过高效的扩散模型恢复清晰的文本。

训练过程

GSDM中主要有两个模块,分别为SPM,RM。两个模型都是以U-Net方式构建,SPM主要是为了预测破损图像的前景图(带修复功能,可以理解为粗糙修复,输出结果为灰度图),RM的输入有三个分别为 SPM的输出(由于是灰度图需要在Channel扩充到3维),破损图,以及DDPM的在T步加噪图像。三个输入concat输入Unet进行修复。训练过程复现过程:
原始论文的代码: 论文github链接,向大佬🫡
修改后的文件目录:
请添加图片描述
下面展示 util.py

import os
import torch
import math
from torchvision import transforms
import random
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from torchvision.utils import make_grid
import cv2

IMG_EXTENSIONS = ['jpg', 'JPG', 'jpeg', 'JPEG',
                  'png', 'PNG', 'ppm', 'PPM', 'bmp', 'BMP']


def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
    '''
    Converts a torch Tensor into an image Numpy array
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    '''
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
    tensor = (tensor - min_max[0]) / \
        (min_max[1] - min_max[0])  # to range [0,1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(
            math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        img_np = (img_np * 255.0).round()
        # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
    return img_np.astype(out_type)


def save_img(img, img_path):
    cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))

class val_dataset(Dataset):
    # Define the val dataset
    def __init__(
        self,
        input_dir,
        data_shape
    ):
        super().__init__()
        self.input_dir = input_dir
        self.img_list = [file_name for file_name in os.listdir(self.input_dir) if file_name.split('.')[-1] in IMG_EXTENSIONS]
        self.resolution = data_shape
        self.img_name = ""
        self.img_trans = transforms.Compose([transforms.Resize(self.resolution),
                                         transforms.ToTensor(),
                                         # transforms.Normalize(self.opt.DATASET.MEAN, self.opt.DATASET.STD)
                                         ])

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        # If there is a problem with an image in the data set, read the next one
        try:
            self.img_name = self.img_list[idx]
            img_input = Image.open(os.path.join(self.input_dir, self.img_name)).convert('RGB')
        except:
            self.img_name = self.img_list[idx+1]
            img_input = Image.open(os.path.join(self.input_dir, self.img_name)).convert('RGB')
        return {
   "image": self.img_trans(img_input), "name": self.img_name}

def save_sp(input:torch._tensor, save_dir:str):
    toPIL = transforms.ToPILImage()
    input = input/2+1
    return toPIL(input.detach().cpu().squeeze()).save(save_dir)

def gray2bgr(input:torch._tensor):
    return torch.cat((input, input, input), dim=1)

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def get_paths_from_images(path):
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                img_path = os.path.join(dirpath, fname)
                images.append(img_path)
    assert images, '{:s} has no valid image file'.format
  • 26
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值