Implementing text2draw with the CLIP model

Reference papers

CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders
StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation

Practice

Code with data augmentation

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F

class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load(
            'ViT-B/32', self.device, jit=False)
        self.model.eval()
        self.preprocess = transforms.Compose(
            [clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])  # clip normalisation
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        # self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups=[pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]), fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                       stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]

        # Image Augmentation Transformation
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])



    def forward(self, t, canvas_width, canvas_height, shapes):

        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # Render the vector shapes to a raster image
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)

        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log_augs/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)
        loss = 0
        NUM_AUGS = 4
        img_augs = []
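        # Encode NUM_AUGS randomly augmented views of the render and accumulate
        # the negative CLIP similarity between each view and the text embedding.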
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
        image_features = self.model.encode_image(im_batch)
        # logit_scale = self.model.logit_scale.exp()
        for n in range(NUM_AUGS):
            loss -= torch.cosine_similarity(self.text_features, image_features[n:n + 1], dim=1)
        return loss


    def compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # Compose img with white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(
            img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(
            torch.float32) / 255.0

        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)
                                ).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(
        torch.float32) / 255.0

    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)
                            ).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm
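    # Place segments * 3 + 1 points evenly on a circle of the given radius around
    # `bias`; these serve as the initial control points of the closed Bezier path.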
    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments = 4

    points1 = get_bezier_circle()

    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1, stroke_width=torch.tensor(2.0), is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        # print(t)
        points_optim.zero_grad()
        # print("match_loss:", match_loss)
        match_loss = matchLoss(t, 224, 224, shapes)

        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})
        # print(points_vars[0])


    pass

Result generated after 1,000 iterations:
[figure]

Code without image augmentation

The only difference from the version above is in forward: the un-augmented render is encoded directly instead of a batch of augmented views.

import math
import collections
import CLIP_.clip as clip
import torch
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import webp
from PIL import Image
import skimage
import torchvision
import pydiffvg
import os
import torch.nn.functional as F

class GeometrymatchLoss(torch.nn.Module):
    def __init__(self, device, reference_images_path):
        super(GeometrymatchLoss, self).__init__()
        self.device = device
        self.model, clip_preprocess = clip.load(
            'ViT-B/32', self.device, jit=False)
        self.model.eval()
        self.preprocess = transforms.Compose(
            [clip_preprocess.transforms[0], clip_preprocess.transforms[-1]])  # clip normalisation
        # self.preprocess = transforms.Compose([clip_preprocess.transforms[-1]])  # clip normalisation
        self.reference_images_feature = self.reference_images_feature(reference_images_path)
        self.reference_images_feature = self.reference_images_feature / self.reference_images_feature.norm(dim=-1, keepdim=True)
        self.text = clip.tokenize(["A picture of triangle"]).to(device)
        # self.text = clip.tokenize(["A picture of rectangle", "A picture of triangle", "A picture of circle", "A picture of pentagon", "A picture of five-pointed star"]).to(device)
        self.text_features = self.model.encode_text(self.text)
        self.text_features = self.text_features / self.text_features.norm(dim=-1, keepdim=True)
        print("text_features.requires_grad:", self.text_features.requires_grad)
        self.text_features = self.text_features.detach()
        self.shape_groups=[pydiffvg.ShapeGroup(shape_ids=torch.tensor([0]), fill_color=torch.tensor([0.0, 0.0, 0.0, 1.0]),
                                       stroke_color=torch.tensor([0.0, 0.0, 0.0, 1.0]))]

        # Image Augmentation Transformation
        self.augment_trans = transforms.Compose([
            transforms.RandomPerspective(fill=1, p=1, distortion_scale=0.5),
            transforms.RandomResizedCrop(224, scale=(0.7, 0.9)),
        ])



    def forward(self, t, canvas_width, canvas_height, shapes):

        scene_args = pydiffvg.RenderFunction.serialize_scene(canvas_width, canvas_height, shapes, self.shape_groups)
        # Render the vector shapes to a raster image
        render = pydiffvg.RenderFunction.apply
        target = render(canvas_width, canvas_height, 2, 2, 0, None, *scene_args)

        if target.shape[-1] == 4:
            target = self.compose_image_with_white_background(target)
        if t % 100 == 0:
            pydiffvg.imwrite(target.cpu(), f'learn/log/output_{t}.png', gamma=2.2)
        # targets_ = self.preprocess(target.permute(2, 0, 1).unsqueeze(0)).to(self.device)
        img = target.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)
        loss = 0
        NUM_AUGS = 4
        img_augs = []
        for n in range(NUM_AUGS):
            img_augs.append(self.augment_trans(img))
        im_batch = torch.cat(img_augs)
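        # NOTE: unlike the augmented version above, the augmented batch `im_batch`
        # is never used here; the single un-augmented render `img` is encoded directly.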
        image_features = self.model.encode_image(img)
        self.targets_features: torch.Tensor = image_features[0]
        self.targets_features = self.targets_features / self.targets_features.norm(dim=-1, keepdim=True)
        loss -= torch.cosine_similarity(self.text_features, self.targets_features, dim=1)

        return loss


    def compose_image_with_white_background(self, img: torch.tensor) -> torch.tensor:
        if img.shape[-1] == 3:  # return img if it is already rgb
            return img
        # Compose img with white background
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha) * torch.ones(
            img.shape[0], img.shape[1], 3, device=self.device)
        return img

    def read_png_image_from_path(self, path_to_png_image: str) -> torch.tensor:
        numpy_image = skimage.io.imread(path_to_png_image)
        normalized_tensor_image = torch.from_numpy(numpy_image).to(
            torch.float32) / 255.0

        resizer = torchvision.transforms.Resize((224, 224))
        resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)
                                ).permute(1, 2, 0)
        return resized_image

    def reference_images_feature(self, reference_images_path):
        reference_images_num = len(os.listdir(reference_images_path))
        reference_images_feature = []
        for i in range(reference_images_num):
            i_reference_image = self.read_png_image_from_path(os.path.join(reference_images_path, str(i) + ".png"))
            if i_reference_image.shape[-1] == 4:
                i_reference_image = self.compose_image_with_white_background(i_reference_image)
            # targets_ = self.preprocess(i_reference_image.permute(2, 0, 1).unsqueeze(0)).to(self.device)
            i_reference_image_features = self.model.encode_image(i_reference_image.permute(2, 0, 1).unsqueeze(0).to(self.device)).detach()
            reference_images_feature.append(i_reference_image_features)
        return torch.cat(reference_images_feature)


def read_png_image_from_path(path_to_png_image: str) -> torch.tensor:
    if path_to_png_image.endswith('.webp'):
        numpy_image = np.array(webp.load_image(path_to_png_image))
    else:
        numpy_image = skimage.io.imread(path_to_png_image)
    normalized_tensor_image = torch.from_numpy(numpy_image).to(
        torch.float32) / 255.0

    resizer = torchvision.transforms.Resize((224, 224))
    resized_image = resizer(normalized_tensor_image.permute(2, 0, 1)
                            ).permute(1, 2, 0)
    return resized_image


if __name__ == '__main__':
    torch.autograd.set_detect_anomaly(True)
    from tqdm import tqdm
    def get_bezier_circle(radius: float = 80,
                          segments: int = 4,
                          bias: np.array = np.asarray([100., 100.])):
        deg = torch.arange(0, segments * 3 + 1) * 2 * np.pi / (segments * 3 + 1)
        points = torch.stack((torch.cos(deg), torch.sin(deg))).T
        points = points * radius + torch.tensor(bias).unsqueeze(dim=0)
        points = points.type(torch.FloatTensor).contiguous()
        return points
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    matchLoss = GeometrymatchLoss(device, "reference_images/")
    # print(matchLoss.reference_images_feature.shape)
    # img1 = read_png_image_from_path('learn/output.png')
    canvas_width, canvas_height = 224, 224
    num_segments = 4

    points1 = get_bezier_circle()

    path = pydiffvg.Path(num_control_points=torch.tensor(num_segments * [2] + [0], dtype=torch.int32),
                         points=points1, stroke_width=torch.tensor(2.0), is_closed=True)
    shapes = [path]
    path.points.requires_grad = True
    print(id(path.points))
    print(id(points1))
    points_vars = []
    points_vars.append(path.points)
    points_optim = torch.optim.Adam(points_vars, lr=1)
    pbar = tqdm(range(100000))
    print(points1)
    for t in pbar:
        # print(t)
        points_optim.zero_grad()
        # print("match_loss:", match_loss)
        match_loss = matchLoss(t, 224, 224, shapes)

        match_loss.backward()
        # print(path.points.grad)
        points_optim.step()
        pbar.set_postfix({"match_loss": f"{match_loss.item()}"})
        # print(points_vars[0])


    pass

Result generated after 1,000 iterations:
[figure]
Result generated after 2,000 iterations:
[figure]
Result generated after 4,000 iterations:
[figure]
Result generated after 8,000 iterations:
[figure]

Analysis: why the results are poor without image augmentation

Explanation from the paper CLIPDraw: Exploring Text-to-Drawing Synthesis through Language-Image Encoders

[figure: excerpt from the CLIPDraw paper]

Explanation from the paper StyleCLIPDraw: Coupling Content and Style in Text-to-Drawing Translation

[figure: excerpt from the StyleCLIPDraw paper]

My understanding

In CLIP's embedding space many different images can match the same text, and to a human eye some of them are fundamentally unrelated to the prompt; without image augmentation the optimisation is very likely to settle in such a degenerate local optimum. Augmenting the image before computing the loss changes this: under perspective and similar transforms, an image that genuinely matches the text keeps roughly the same similarity no matter how it is transformed, whereas an unrelated image usually loses its similarity once transformed. Averaging the loss over several augmented views therefore suppresses the influence of these unrelated images on the result.
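To make this concrete, the two objectives can be written side by side. This is only a minimal sketch, not the training code above: encode_image, text_features and augment stand in for the CLIP image encoder, the detached text embedding and the random perspective/crop transform used in the code.

import torch

def loss_no_aug(img, text_features, encode_image):
    # Single view: an adversarial-looking render can already score a high similarity.
    feat = encode_image(img)
    feat = feat / feat.norm(dim=-1, keepdim=True)
    return -torch.cosine_similarity(text_features, feat, dim=1)

def loss_with_aug(img, text_features, encode_image, augment, num_augs=4):
    # Average over random views: only renders that stay similar to the text under
    # every perspective/crop keep the loss low, which filters out degenerate images.
    batch = torch.cat([augment(img) for _ in range(num_augs)])
    feats = encode_image(batch)
    feats = feats / feats.norm(dim=-1, keepdim=True)
    return -torch.cosine_similarity(text_features, feats, dim=1).mean()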
