Text Image Inpainting via Global Structure-Guided Diffusion Models(GSDM)
这份文件是一篇关于文本图像修复(Text Image Inpainting)的研究论文,标题为“Text Image Inpainting via Global Structure-Guided Diffusion Models”,作者包括Shipeng Zhu、Pengfei Fang、Chenjie Zhu、Zuoyan Zhao、Qiang Xu、Hui Xue,来自中国东南大学计算机科学与工程学院以及教育部新一代人工智能技术及其交叉应用重点实验室。
研究背景
现实世界中的文本图像可能会因环境或人为因素而损坏,如腐蚀问题,这影响了文本的完整性,包括纹理和结构。这些问题给文本的理解和下游应用(如场景文本识别和签名识别)带来了挑战。
研究目标
本文旨在解决文本图像修复问题,建立基准数据集,并开发了一种新的神经网络框架——全局结构引导的扩散模型(Global Structure-guided Diffusion Model, GSDM)。
方法论
数据集构建:创建了两个特定的文本修复数据集,分别包含场景文本图像和手写文本图像,每个数据集包含原始图像、损坏图像以及其他辅助信息。
GSDM模型:提出了一种新的神经网络模型,利用文本的全局结构作为先验知识,通过高效的扩散模型恢复清晰的文本。
训练过程
GSDM中主要有两个模块,分别为SPM,RM。两个模型都是以U-Net方式构建,SPM主要是为了预测破损图像的前景图(带修复功能,可以理解为粗糙修复,输出结果为灰度图),RM的输入有三个分别为 SPM的输出(由于是灰度图需要在Channel扩充到3维),破损图,以及DDPM的在T步加噪图像。三个输入concat输入Unet进行修复。训练过程复现过程:
原始论文的代码: 论文github链接,向大佬🫡
修改后的文件目录:
下面展示 util.py
。
import os
import torch
import math
from torchvision import transforms
import random
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from torchvision.utils import make_grid
import cv2
IMG_EXTENSIONS = ['jpg', 'JPG', 'jpeg', 'JPEG',
'png', 'PNG', 'ppm', 'PPM', 'bmp', 'BMP']
def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
'''
Converts a torch Tensor into an image Numpy array
Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
'''
tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
tensor = (tensor - min_max[0]) / \
(min_max[1] - min_max[0]) # to range [0,1]
n_dim = tensor.dim()
if n_dim == 4:
n_img = len(tensor)
img_np = make_grid(tensor, nrow=int(
math.sqrt(n_img)), normalize=False).numpy()
img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
elif n_dim == 3:
img_np = tensor.numpy()
img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
elif n_dim == 2:
img_np = tensor.numpy()
else:
raise TypeError(
'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
if out_type == np.uint8:
img_np = (img_np * 255.0).round()
# Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
return img_np.astype(out_type)
def save_img(img, img_path):
cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
class val_dataset(Dataset):
# Define the val dataset
def __init__(
self,
input_dir,
data_shape
):
super().__init__()
self.input_dir = input_dir
self.img_list = [file_name for file_name in os.listdir(self.input_dir) if file_name.split('.')[-1] in IMG_EXTENSIONS]
self.resolution = data_shape
self.img_name = ""
self.img_trans = transforms.Compose([transforms.Resize(self.resolution),
transforms.ToTensor(),
# transforms.Normalize(self.opt.DATASET.MEAN, self.opt.DATASET.STD)
])
def __len__(self):
return len(self.img_list)
def __getitem__(self, idx):
# If there is a problem with an image in the data set, read the next one
try:
self.img_name = self.img_list[idx]
img_input = Image.open(os.path.join(self.input_dir, self.img_name)).convert('RGB')
except:
self.img_name = self.img_list[idx+1]
img_input = Image.open(os.path.join(self.input_dir, self.img_name)).convert('RGB')
return {
"image": self.img_trans(img_input), "name": self.img_name}
def save_sp(input:torch._tensor, save_dir:str):
toPIL = transforms.ToPILImage()
input = input/2+1
return toPIL(input.detach().cpu().squeeze()).save(save_dir)
def gray2bgr(input:torch._tensor):
return torch.cat((input, input, input), dim=1)
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def get_paths_from_images(path):
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
images = []
for dirpath, _, fnames in sorted(os.walk(path)):
for fname in sorted(fnames):
if is_image_file(fname):
img_path = os.path.join(dirpath, fname)
images.append(img_path)
assert images, '{:s} has no valid image file'.format