(5-3-03) Common Text-to-Image Pretrained Models: A Text-to-Image System Based on VQGAN+CLIP (2) Generating Images

4. Generating Images

The file generate.py is the core of the project: it generates images according to the command-line arguments and assembles them into a video. The file consists of the following four main parts:

(1) Optimizer configuration (the get_opt function): selects and returns the appropriate optimizer instance based on the given optimizer name (such as Adam or AdamW) and learning rate.

(2) Parameter and settings initialization: reads the user-supplied device, optimizer, prompts, initial image, and other settings, and initializes the random seed to make results reproducible.

(3) Image generation and training loop: initializes the iteration counter, video-frame counter, and related state, then enters the main training loop. This part generates images, applies zoom and shift effects, updates the text prompts, and trains step by step; the images are saved and processed so that the final video can be produced.

(4) Video generation: assembles the generated image sequence into a video, optionally using hardware encoding for efficient processing. Based on the configured frame rate and video length, the frames are written to a video file, with optional frame interpolation to improve quality.

Overall, generate.py implements the complete pipeline from image generation to video synthesis: it generates images by iterative optimization and assembles them into a high-quality video. The implementation of generate.py proceeds as follows.
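Before stepping through the code, here is an illustrative invocation (the prompt text, iteration count, and output filename are arbitrary examples; it assumes the VQGAN config and checkpoint are in the default checkpoints/ directory, as set by the -conf and -ckpt defaults):

python generate.py -p "A painting of an apple in a fruit bowl" -i 500 -se 50 -o apple.png

This optimizes for 500 iterations and saves the current image to apple.png every 50 iterations.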

(1) Import the required libraries and modules, set up command-line argument parsing, process the user-supplied text and image prompts, configure the GPU and image size, and prepare the environment for image generation. These settings lay the groundwork for generating images with VQGAN+CLIP according to the user's prompts and parameters.

import argparse  # command-line argument parsing
import math  # mathematical operations
import random  # random number generation
# from email.policy import default  # unused, commented-out import
from urllib.request import urlopen  # open URLs
from tqdm import tqdm  # progress bar display
import sys  # system-specific parameters and functions
import os  # file and directory operations

# pip install taming-transformers does not support Gumbel, and does not yet support coco etc.
# Appending this path does support Gumbel, but causes ModuleNotFoundError: No module named 'transformers'
sys.path.append('taming-transformers')

from omegaconf import OmegaConf  # configuration file handling
from taming.models import cond_transformer, vqgan  # import the VQGAN models

import torch  # import PyTorch
from torch import nn, optim  # neural network and optimizer modules
from torch.nn import functional as F  # functional neural-network operations
from torchvision import transforms  # image transformation module
from torchvision.transforms import functional as TF  # functional image transforms
from torch.cuda import get_device_properties  # query GPU device properties
torch.backends.cudnn.benchmark = False  # NR: True is slightly faster, but can cause OOM (out of memory). False is more deterministic.
# torch.use_deterministic_algorithms(True)  # NR: grid_sampler_2d_backward_cuda does not have a deterministic implementation

from torch_optimizer import DiffGrad, AdamP  # additional optimizers

from CLIP import clip  # import the CLIP model
import kornia.augmentation as K  # image augmentation library
import numpy as np  # array processing
import imageio  # image input/output

from PIL import ImageFile, Image, PngImagePlugin, ImageChops  # PIL image processing
ImageFile.LOAD_TRUNCATED_IMAGES = True  # allow loading truncated images

from subprocess import Popen, PIPE  # spawn subprocesses
import re  # regular expressions

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')  # ignore warning messages

# Check for a GPU and reduce the default image size if VRAM is limited
default_image_size = 512  # >8GB VRAM
if not torch.cuda.is_available():  # is a GPU available?
    default_image_size = 256  # no GPU found
elif get_device_properties(0).total_memory <= 2 ** 33:  # 2 ** 33 = 8,589,934,592 bytes = 8 GB
    default_image_size = 304  # <8GB VRAM

# Create the argument parser
vq_parser = argparse.ArgumentParser(description='Image generation using VQGAN+CLIP')

# Add the arguments
vq_parser.add_argument("-p",    "--prompts", type=str, help="Text prompts", default=None, dest='prompts')
vq_parser.add_argument("-ip",   "--image_prompts", type=str, help="Image prompts / target image", default=[], dest='image_prompts')
vq_parser.add_argument("-i",    "--iterations", type=int, help="Number of iterations", default=500, dest='max_iterations')
vq_parser.add_argument("-se",   "--save_every", type=int, help="Save an image every so many iterations", default=50, dest='display_freq')
vq_parser.add_argument("-s",    "--size", nargs=2, type=int, help="Image size (width height) (default: %(default)s)", default=[default_image_size, default_image_size], dest='size')
vq_parser.add_argument("-ii",   "--init_image", type=str, help="Initial image", default=None, dest='init_image')
vq_parser.add_argument("-in",   "--init_noise", type=str, help="Initial noise image (pixels or gradient)", default=None, dest='init_noise')
vq_parser.add_argument("-iw",   "--init_weight", type=float, help="Initial weight", default=0., dest='init_weight')
vq_parser.add_argument("-m",    "--clip_model", type=str, help="CLIP model (e.g. ViT-B/32, ViT-B/16)", default='ViT-B/32', dest='clip_model')
vq_parser.add_argument("-conf", "--vqgan_config", type=str, help="VQGAN config", default=f'checkpoints/vqgan_imagenet_f16_16384.yaml', dest='vqgan_config')
vq_parser.add_argument("-ckpt", "--vqgan_checkpoint", type=str, help="VQGAN checkpoint", default=f'checkpoints/vqgan_imagenet_f16_16384.ckpt', dest='vqgan_checkpoint')
vq_parser.add_argument("-nps",  "--noise_prompt_seeds", nargs="*", type=int, help="Noise prompt seeds", default=[], dest='noise_prompt_seeds')
vq_parser.add_argument("-npw",  "--noise_prompt_weights", nargs="*", type=float, help="Noise prompt weights", default=[], dest='noise_prompt_weights')
vq_parser.add_argument("-lr",   "--learning_rate", type=float, help="Learning rate", default=0.1, dest='step_size')
vq_parser.add_argument("-cutm", "--cut_method", type=str, help="Cut method", choices=['original','updated','nrupdated','updatedpooling','latest'], default='latest', dest='cut_method')
vq_parser.add_argument("-cuts", "--num_cuts", type=int, help="Number of cuts", default=32, dest='cutn')
vq_parser.add_argument("-cutp", "--cut_power", type=float, help="Cut power", default=1., dest='cut_pow')
vq_parser.add_argument("-sd",   "--seed", type=int, help="Seed", default=None, dest='seed')
vq_parser.add_argument("-opt",  "--optimiser", type=str, help="Optimiser", choices=['Adam','AdamW','Adagrad','Adamax','DiffGrad','AdamP','RAdam','RMSprop'], default='Adam', dest='optimiser')
vq_parser.add_argument("-o",    "--output", type=str, help="Output image filename", default="output.png", dest='output')
vq_parser.add_argument("-vid",  "--video", action='store_true', help="Create video frames?", dest='make_video')
vq_parser.add_argument("-zvid", "--zoom_video", action='store_true', help="Create a zoom video?", dest='make_zoom_video')
vq_parser.add_argument("-zs",   "--zoom_start", type=int, help="Zoom start iteration", default=0, dest='zoom_start')
vq_parser.add_argument("-zse",  "--zoom_save_every", type=int, help="Save a zoom image every so many iterations", default=10, dest='zoom_frequency')
vq_parser.add_argument("-zsc",  "--zoom_scale", type=float, help="Zoom scale %%", default=0.99, dest='zoom_scale')
vq_parser.add_argument("-zsx",  "--zoom_shift_x", type=int, help="Zoom shift x (left/right) in pixels", default=0, dest='zoom_shift_x')
vq_parser.add_argument("-zsy",  "--zoom_shift_y", type=int, help="Zoom shift y (up/down) in pixels", default=0, dest='zoom_shift_y')
vq_parser.add_argument("-cpe",  "--change_prompt_every", type=int, help="Prompt change frequency", default=0, dest='prompt_frequency')
vq_parser.add_argument("-vl",   "--video_length", type=float, help="Video length in seconds (not interpolated)", default=10, dest='video_length')
vq_parser.add_argument("-ofps", "--output_video_fps", type=float, help="Create an interpolated video (Nvidia GPU only) with this output frame rate (min 10, best set to 30 or 60)", default=0, dest='output_video_fps')
vq_parser.add_argument("-ifps", "--input_video_fps", type=float, help="When creating an interpolated video, use this input frame rate for interpolation (>0 & <ofps)", default=15, dest='input_video_fps')
vq_parser.add_argument("-d",    "--deterministic", action='store_true', help="Enable cudnn.deterministic?", dest='cudnn_determinism')
vq_parser.add_argument("-aug",  "--augments", nargs='+', action='append', type=str, choices=['Ji','Sh','Gn','Pe','Ro','Af','Et','Ts','Cr','Er','Re'], help="Enabled augments (latest cut method only)", default=[], dest='augments')
vq_parser.add_argument("-vsd",  "--video_style_dir", type=str, help="Directory of video frames to style", default=None, dest='video_style_dir')
vq_parser.add_argument("-cd",   "--cuda_device", type=str, help="CUDA device to use", default="cuda:0", dest='cuda_device')

# Execute the parse_args() method
args = vq_parser.parse_args()

# If no prompts and no image prompts were given, use a default prompt
if not args.prompts and not args.image_prompts:
    args.prompts = "A cute, smiling, nerdy rodent"

# If determinism was requested, make cudnn deterministic
if args.cudnn_determinism:
    torch.backends.cudnn.deterministic = True

# If no augments were specified, use these defaults
if not args.augments:
    args.augments = [['Af', 'Pe', 'Ji', 'Er']]

# Split text prompts using the pipe character (weights are split later)
if args.prompts:
    # For stories, there will be many phrases
    story_phrases = [phrase.strip() for phrase in args.prompts.split("^")]

    # Make a list of all phrases
    all_phrases = []
    for phrase in story_phrases:
        all_phrases.append(phrase.split("|"))

    # Take the first phrase
    args.prompts = all_phrases[0]
    
# Split target images using the pipe character (weights are split later)
if args.image_prompts:
    args.image_prompts = args.image_prompts.split("|")
    args.image_prompts = [image.strip() for image in args.image_prompts]

# Warn if both make_video and make_zoom_video are set
if args.make_video and args.make_zoom_video:
    print("Warning: creating a video and creating a zoom video are mutually exclusive.")
    args.make_video = False
    
# Make the steps directory for video frames
if args.make_video or args.make_zoom_video:
    if not os.path.exists('steps'):
        os.mkdir('steps')

# Fall back to the CPU if CUDA is not found, and make sure GPU video rendering is also disabled
# NB. May not work for AMD cards?
if not args.cuda_device == 'cpu' and not torch.cuda.is_available():
    args.cuda_device = 'cpu'
    args.video_fps = 0
    print("Warning: no GPU found! Using the CPU instead. Iterations will be slow.")
    print("Perhaps CUDA/ROCm or the correct pytorch version is not properly installed?")

# If a video_style_dir was given, build a list of all the images in it
if args.video_style_dir:
    print("Locating video frames...")
    video_frame_list = []
    for entry in os.scandir(args.video_style_dir):
        if (entry.path.endswith(".jpg") or entry.path.endswith(".png")) and entry.is_file():
            video_frame_list.append(entry.path)

    # Reset a few options - same filename, different directory
    if not os.path.exists('steps'):
        os.mkdir('steps')

    args.init_image = video_frame_list[0]  # use the first frame as the initial image
    filename = os.path.basename(args.init_image)  # get the filename
    cwd = os.getcwd()  # get the current working directory
    args.output = os.path.join(cwd, "steps", filename)  # set the output path
    num_video_frames = len(video_frame_list)  # number of frames to be styled

(2) The function sinc(x) computes the sinc function, sinc(x) = sin(πx) / (πx), returning 1 when x = 0.

def sinc(x):
    return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))

(3) The function lanczos(x, a) builds a Lanczos kernel for image resampling: it evaluates the windowed sinc function on the interval (-a, a) and normalizes the result to sum to 1.

def lanczos(x, a):
    cond = torch.logical_and(-a < x, x < a)
    out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
    return out / out.sum()

(4) The function ramp(ratio, width) generates a symmetric sequence of sample positions spaced linearly by ratio over the given width; it supplies the tap positions for the Lanczos kernel.

def ramp(ratio, width):
    n = math.ceil(width / ratio + 1)
    out = torch.empty([n])
    cur = 0
    for i in range(out.shape[0]):
        out[i] = cur
        cur += ratio
    return torch.cat([-out[1:].flip([0]), out])[1:-1]
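For intuition, a hand-worked example (computed from the function above, not taken from the original text): with ratio=0.5 and width=2 the function builds n = ceil(2/0.5 + 1) = 5 values [0, 0.5, 1, 1.5, 2], mirrors them to the negative side, and trims both end points:

ramp(0.5, 2)
# tensor([-1.5000, -1.0000, -0.5000,  0.0000,  0.5000,  1.0000,  1.5000])

These symmetric positions are what lanczos() evaluates and normalizes into the resampling kernel used by resample() below.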

(5) The function zoom_at(img, x, y, zoom) zooms the input image img around the point (x, y): it crops a window whose size depends on the zoom factor and resizes it back to the original size with the LANCZOS filter.

def zoom_at(img, x, y, zoom):
    w, h = img.size
    zoom2 = zoom * 2
    img = img.crop((x - w / zoom2, y - h / zoom2,
                    x + w / zoom2, y + h / zoom2))
    return img.resize((w, h), Image.LANCZOS)

(6) The function random_noise_image(w, h) generates a random noise image with width w and height h.

def random_noise_image(w,h):
    random_image = Image.fromarray(np.random.randint(0,255,(w,h,3),dtype=np.dtype('uint8')))
    return random_image

(7) The function gradient_2d(start, stop, width, height, is_horizontal) generates a two-dimensional gradient, horizontal or vertical depending on the flag.

def gradient_2d(start, stop, width, height, is_horizontal):
    if is_horizontal:
        return np.tile(np.linspace(start, stop, width), (height, 1))
    else:
        return np.tile(np.linspace(start, stop, height), (width, 1)).T

(8) The function gradient_3d(width, height, start_list, stop_list, is_horizontal_list) generates a three-dimensional (multi-channel) gradient image, one 2-D gradient per channel.

def gradient_3d(width, height, start_list, stop_list, is_horizontal_list):
    result = np.zeros((height, width, len(start_list)), dtype=float)

    for i, (start, stop, is_horizontal) in enumerate(zip(start_list, stop_list, is_horizontal_list)):
        result[:, :, i] = gradient_2d(start, stop, width, height, is_horizontal)

    return result

(9) The function random_gradient_image(w, h) generates a random gradient image of width w and height h with random colour values.

def random_gradient_image(w,h):
    array = gradient_3d(w, h, (0, 0, np.random.randint(0,255)), (np.random.randint(1,255), np.random.randint(2,255), np.random.randint(3,128)), (True, False, False))
    random_image = Image.fromarray(np.uint8(array))
    return random_image

(10) The function resample(input, size, align_corners=True) resamples the input tensor to the given size: when downscaling it first low-pass filters with a Lanczos kernel, then interpolates bicubically.

def resample(input, size, align_corners=True):
    n, c, h, w = input.shape
    dh, dw = size

    input = input.view([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.view([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)

(11) The class ReplaceGrad is a custom PyTorch autograd Function that returns x_forward in the forward pass but, during backpropagation, routes the gradient to x_backward instead.

class ReplaceGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_forward, x_backward):
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        return None, grad_in.sum_to_size(ctx.shape)

replace_grad = ReplaceGrad.apply

(12) ClampWithGrad is a custom PyTorch autograd Function that clamps the input to a given range in the forward pass while still providing a useful gradient in the backward pass.

class ClampWithGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max):
        ctx.min = min
        ctx.max = max
        ctx.save_for_backward(input)
        return input.clamp(min, max)

    @staticmethod
    def backward(ctx, grad_in):
        input, = ctx.saved_tensors
        return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None

clamp_with_grad = ClampWithGrad.apply

(13) The function vector_quantize(x, codebook) implements vector quantization: each vector in x is mapped to its nearest entry in the given codebook, and replace_grad turns the mapping into a straight-through estimator.

def vector_quantize(x, codebook):
    d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
    indices = d.argmin(-1)
    x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
    return replace_grad(x_q, x)
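To make the quantization step concrete, here is a minimal sketch with a toy 2-D codebook (the names toy_codebook and toy_x and their values are made up for illustration):

toy_codebook = torch.tensor([[0., 0.], [1., 1.]])   # two 2-D code vectors
toy_x = torch.tensor([[0.2, 0.1], [0.9, 0.8]])      # two input vectors
d = toy_x.pow(2).sum(dim=-1, keepdim=True) + toy_codebook.pow(2).sum(dim=1) - 2 * toy_x @ toy_codebook.T
print(d.argmin(-1))                                 # tensor([0, 1]): each row snaps to its nearest code vector

In generate.py the same distance computation runs over the VQGAN codebook, and replace_grad makes the snap behave as a straight-through estimator: the forward pass uses the quantized x_q while gradients flow back to x unchanged.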

(14) The class Prompt represents one prompt: it measures the (squared spherical) distance between the input embeddings and the stored target embedding, scales the result by the prompt weight, and honours the stop threshold.

class Prompt(nn.Module):
    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        self.register_buffer('embed', embed)
        self.register_buffer('weight', torch.as_tensor(weight))
        self.register_buffer('stop', torch.as_tensor(stop))

    def forward(self, input):
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
        dists = dists * self.weight.sign()
        return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()

(15) The function split_prompt(prompt) splits a prompt string into the base prompt, its weight, and its stop value.

def split_prompt(prompt):
    vals = prompt.rsplit(':', 2)
    vals = vals + ['', '1', '-inf'][len(vals):]
    return vals[0], float(vals[1]), float(vals[2])
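For example (illustrative prompt strings, following the text:weight:stop convention parsed above):

split_prompt("a watercolour city")           # ('a watercolour city', 1.0, -inf)
split_prompt("a watercolour city:0.5")       # ('a watercolour city', 0.5, -inf)
split_prompt("a watercolour city:0.5:-1")    # ('a watercolour city', 0.5, -1.0)

Missing fields fall back to a weight of 1 and a stop of -inf.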

(16) The class MakeCutouts generates image cutouts and can apply different data-augmentation methods. It takes the cut size, the number of cuts, and other augmentation parameters, and its forward pass returns the batch of augmented cutouts.

class MakeCutouts(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow # not used with pooling
        
        # Pick your own augments & their order
        augment_list = []
        for item in args.augments[0]:
            if item == 'Ji':
                augment_list.append(K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7))
            elif item == 'Sh':
                augment_list.append(K.RandomSharpness(sharpness=0.3, p=0.5))
            elif item == 'Gn':
                augment_list.append(K.RandomGaussianNoise(mean=0.0, std=1., p=0.5))
            elif item == 'Pe':
                augment_list.append(K.RandomPerspective(distortion_scale=0.7, p=0.7))
            elif item == 'Ro':
                augment_list.append(K.RandomRotation(degrees=15, p=0.7))
            elif item == 'Af':
                augment_list.append(K.RandomAffine(degrees=15, translate=0.1, shear=5, p=0.7, padding_mode='zeros', keepdim=True)) # border, reflection, zeros
            elif item == 'Et':
                augment_list.append(K.RandomElasticTransform(p=0.7))
            elif item == 'Ts':
                augment_list.append(K.RandomThinPlateSpline(scale=0.8, same_on_batch=True, p=0.7))
            elif item == 'Cr':
                augment_list.append(K.RandomCrop(size=(self.cut_size,self.cut_size), pad_if_needed=True, padding_mode='reflect', p=0.5))
            elif item == 'Er':
                augment_list.append(K.RandomErasing(scale=(.1, .4), ratio=(.3, 1/.3), same_on_batch=True, p=0.7))
            elif item == 'Re':
                augment_list.append(K.RandomResizedCrop(size=(self.cut_size,self.cut_size), scale=(0.1,1),  ratio=(0.75,1.333), cropping_mode='resample', p=0.5))
                
        self.augs = nn.Sequential(*augment_list)
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))

    def forward(self, input):
        cutouts = []
        
        for _ in range(self.cutn):            
            # Use Pooling
            cutout = (self.av_pool(input) + self.max_pool(input))/2
            cutouts.append(cutout)
            
        batch = self.augs(torch.cat(cutouts, dim=0))
        
        if self.noise_fac:
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch

(17) The class MakeCutoutsPoolingUpdate generates cutouts by blending adaptive average and max pooling and then applies a fixed sequence of augmentations, providing a simple pooled cutout strategy.

class MakeCutoutsPoolingUpdate(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow # Not used with pooling

        self.augs = nn.Sequential(
            K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border'),
            K.RandomPerspective(0.7,p=0.7),
            K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
            K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7),            
        )
        
        self.noise_fac = 0.1
        self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
        self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        
        for _ in range(self.cutn):
            cutout = (self.av_pool(input) + self.max_pool(input))/2
            cutouts.append(cutout)
            
        batch = self.augs(torch.cat(cutouts, dim=0))
        
        if self.noise_fac:
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch

(18) The class MakeCutoutsNRUpdate builds its augmentation pipeline dynamically from the augments chosen by the user, so several augmentation methods can be combined. In the forward pass it samples random cutout sizes and positions, resamples each crop to the cut size, and then applies the selected augmentations to increase image diversity.

class MakeCutoutsNRUpdate(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.noise_fac = 0.1
        
        # Pick your own augments & their order
        augment_list = []
        for item in args.augments[0]:
            if item == 'Ji':
                augment_list.append(K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.7))
            elif item == 'Sh':
                augment_list.append(K.RandomSharpness(sharpness=0.3, p=0.5))
            elif item == 'Gn':
                augment_list.append(K.RandomGaussianNoise(mean=0.0, std=1., p=0.5))
            elif item == 'Pe':
                augment_list.append(K.RandomPerspective(distortion_scale=0.5, p=0.7))
            elif item == 'Ro':
                augment_list.append(K.RandomRotation(degrees=15, p=0.7))
            elif item == 'Af':
                augment_list.append(K.RandomAffine(degrees=30, translate=0.1, shear=5, p=0.7, padding_mode='zeros', keepdim=True)) # border, reflection, zeros
            elif item == 'Et':
                augment_list.append(K.RandomElasticTransform(p=0.7))
            elif item == 'Ts':
                augment_list.append(K.RandomThinPlateSpline(scale=0.8, same_on_batch=True, p=0.7))
            elif item == 'Cr':
                augment_list.append(K.RandomCrop(size=(self.cut_size,self.cut_size), pad_if_needed=True, padding_mode='reflect', p=0.5))
            elif item == 'Er':
                augment_list.append(K.RandomErasing(scale=(.1, .4), ratio=(.3, 1/.3), same_on_batch=True, p=0.7))
            elif item == 'Re':
                augment_list.append(K.RandomResizedCrop(size=(self.cut_size,self.cut_size), scale=(0.1,1),  ratio=(0.75,1.333), cropping_mode='resample', p=0.5))
                
        self.augs = nn.Sequential(*augment_list)


    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        batch = self.augs(torch.cat(cutouts, dim=0))
        if self.noise_fac:
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch

(19) The class MakeCutoutsUpdate generates randomly sized cutouts and applies a fixed sequence of augmentations, a simple but effective processing scheme.

class MakeCutoutsUpdate(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
            # K.RandomSolarize(0.01, 0.01, p=0.7),
            K.RandomSharpness(0.3,p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2,p=0.4),)
        self.noise_fac = 0.1


    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        batch = self.augs(torch.cat(cutouts, dim=0))
        if self.noise_fac:
            facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch

(20) The class MakeCutoutsOrig generates cutouts without applying any augmentation or pooling; it suits scenarios that only need basic cutouts and returns the crops clamped to [0, 1].

class MakeCutoutsOrig(nn.Module):
    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def forward(self, input):
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        return clamp_with_grad(torch.cat(cutouts, dim=0), 0, 1)

(21) The function load_vqgan_model loads a VQGAN model from the given config file and checkpoint. It supports several model types and returns an initialized model instance for the subsequent image-generation steps.

def load_vqgan_model(config_path, checkpoint_path):
    global gumbel
    gumbel = False
    config = OmegaConf.load(config_path)
    if config.model.target == 'taming.models.vqgan.VQModel':
        model = vqgan.VQModel(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
    elif config.model.target == 'taming.models.vqgan.GumbelVQ':
        model = vqgan.GumbelVQ(**config.model.params)
        model.eval().requires_grad_(False)
        model.init_from_ckpt(checkpoint_path)
        gumbel = True
    elif config.model.target == 'taming.models.cond_transformer.Net2NetTransformer':
        parent_model = cond_transformer.Net2NetTransformer(**config.model.params)
        parent_model.eval().requires_grad_(False)
        parent_model.init_from_ckpt(checkpoint_path)
        model = parent_model.first_stage_model
    else:
        raise ValueError(f'unknown model type: {config.model.target}')
    del model.loss
    return model

(22) The function resize_image adjusts the input image towards the requested output size while keeping its aspect ratio: it caps the pixel area at that of the output size and resizes with high-quality Lanczos resampling. For example, a 1024x512 image with out_size (512, 512) becomes 724x362 - the same aspect ratio, with roughly the same area as 512x512.

def resize_image(image, out_size):
    ratio = image.size[0] / image.size[1]
    area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
    size = round((area * ratio)**0.5), round((area / ratio)**0.5)
    return image.resize(size, Image.LANCZOS)

device = torch.device(args.cuda_device)
model = load_vqgan_model(args.vqgan_config, args.vqgan_checkpoint).to(device)
jit = True if "1.7.1" in torch.__version__ else False
perceptor = clip.load(args.clip_model, jit=jit)[0].eval().requires_grad_(False).to(device)


cut_size = perceptor.visual.input_resolution
f = 2**(model.decoder.num_resolutions - 1)

if args.cut_method == 'latest':
    make_cutouts = MakeCutouts(cut_size, args.cutn, cut_pow=args.cut_pow)
elif args.cut_method == 'original':
    make_cutouts = MakeCutoutsOrig(cut_size, args.cutn, cut_pow=args.cut_pow)
elif args.cut_method == 'updated':
    make_cutouts = MakeCutoutsUpdate(cut_size, args.cutn, cut_pow=args.cut_pow)
elif args.cut_method == 'nrupdated':
    make_cutouts = MakeCutoutsNRUpdate(cut_size, args.cutn, cut_pow=args.cut_pow)
else:
    make_cutouts = MakeCutoutsPoolingUpdate(cut_size, args.cutn, cut_pow=args.cut_pow)    

toksX, toksY = args.size[0] // f, args.size[1] // f
sideX, sideY = toksX * f, toksY * f

# Gumbel or not?
if gumbel:
    e_dim = 256
    n_toks = model.quantize.n_embed
    z_min = model.quantize.embed.weight.min(dim=0).values[None, :, None, None]
    z_max = model.quantize.embed.weight.max(dim=0).values[None, :, None, None]
else:
    e_dim = model.quantize.e_dim
    n_toks = model.quantize.n_e
    z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
    z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]


if args.init_image:
    if 'http' in args.init_image:
      img = Image.open(urlopen(args.init_image))
    else:
      img = Image.open(args.init_image)
    pil_image = img.convert('RGB')
    pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
    pil_tensor = TF.to_tensor(pil_image)
    z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
elif args.init_noise == 'pixels':
    img = random_noise_image(args.size[0], args.size[1])    
    pil_image = img.convert('RGB')
    pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
    pil_tensor = TF.to_tensor(pil_image)
    z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
elif args.init_noise == 'gradient':
    img = random_gradient_image(args.size[0], args.size[1])
    pil_image = img.convert('RGB')
    pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
    pil_tensor = TF.to_tensor(pil_image)
    z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
else:
    one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=device), n_toks).float()
    # z = one_hot @ model.quantize.embedding.weight
    if gumbel:
        z = one_hot @ model.quantize.embed.weight
    else:
        z = one_hot @ model.quantize.embedding.weight

    z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2) 
    #z = torch.rand_like(z)*2						# NR: check

z_orig = z.clone()
z.requires_grad_(True)

pMs = []
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                  std=[0.26862954, 0.26130258, 0.27577711])

 
if args.prompts:
    for prompt in args.prompts:
        txt, weight, stop = split_prompt(prompt)
        embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
        pMs.append(Prompt(embed, weight, stop).to(device))

for prompt in args.image_prompts:
    path, weight, stop = split_prompt(prompt)
    img = Image.open(path)
    pil_image = img.convert('RGB')
    img = resize_image(pil_image, (sideX, sideY))
    batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(device))
    embed = perceptor.encode_image(normalize(batch)).float()
    pMs.append(Prompt(embed, weight, stop).to(device))

for seed, weight in zip(args.noise_prompt_seeds, args.noise_prompt_weights):
    gen = torch.Generator().manual_seed(seed)
    embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
    pMs.append(Prompt(embed, weight).to(device))

(23) The function get_opt creates an optimizer instance from the given optimizer name and learning rate. It supports several optimizers (such as Adam, AdamW, and RMSprop) and returns the initialized optimizer object used in the training loop.

def get_opt(opt_name, opt_lr):
    if opt_name == "Adam":
        opt = optim.Adam([z], lr=opt_lr)	# LR=0.1 (Default)
    elif opt_name == "AdamW":
        opt = optim.AdamW([z], lr=opt_lr)	
    elif opt_name == "Adagrad":
        opt = optim.Adagrad([z], lr=opt_lr)	
    elif opt_name == "Adamax":
        opt = optim.Adamax([z], lr=opt_lr)	
    elif opt_name == "DiffGrad":
        opt = DiffGrad([z], lr=opt_lr, eps=1e-9, weight_decay=1e-9) # NR: Playing for reasons
    elif opt_name == "AdamP":
        opt = AdamP([z], lr=opt_lr)		    
    elif opt_name == "RAdam":
        opt = optim.RAdam([z], lr=opt_lr)		    
    elif opt_name == "RMSprop":
        opt = optim.RMSprop([z], lr=opt_lr)
    else:
        print("Unknown optimiser. Are choices broken?")
        opt = optim.Adam([z], lr=opt_lr)
    return opt

(24) The code below creates the optimizer, prints the current training configuration, and initializes the random seed. It calls get_opt to build the optimizer and prints the device, optimizer, text and image prompts, initial image, and noise prompt weights in use. It also honours a user-supplied seed so that subsequent runs are reproducible.

opt = get_opt(args.optimiser, args.step_size)

print('Using device:', device)
print('Optimising using:', args.optimiser)

if args.prompts:
    print('Using text prompts:', args.prompts)  
if args.image_prompts:
    print('Using image prompts:', args.image_prompts)
if args.init_image:
    print('Using initial image:', args.init_image)
if args.noise_prompt_weights:
    print('Noise prompt weights:', args.noise_prompt_weights)    


if args.seed is None:
    seed = torch.seed()
else:
    seed = args.seed  
torch.manual_seed(seed)
print('Using seed:', seed)

(25) The function synth vector-quantizes the latent variable z and decodes it into a viewable image. It selects the quantization codebook according to the gumbel flag, decodes with the model's decoder, and rescales the output to [0, 1] for display.

def synth(z):
    if gumbel:
        z_q = vector_quantize(z.movedim(1, 3), model.quantize.embed.weight).movedim(3, 1)
    else:
        z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
    return clamp_with_grad(model.decode(z_q).add(1).div(2), 0, 1)

(26) The function checkin periodically reports the current iteration, the total loss, and the individual loss terms during training. It also saves the current generated image as a PNG, embedding the user's text prompts as metadata for later reference.

#@torch.no_grad()
@torch.inference_mode()
def checkin(i, losses):
    losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
    tqdm.write(f'i: {i}, loss: {sum(losses).item():g}, losses: {losses_str}')
    out = synth(z)
    info = PngImagePlugin.PngInfo()
    info.add_text('comment', f'{args.prompts}')
    TF.to_pil_image(out[0].cpu()).save(args.output, pnginfo=info) 	

(27) The function ascend_txt computes the losses between the current generated image and the prompts. It synthesizes the image, encodes its cutouts with CLIP, optionally adds the init-weight regularization term, evaluates every Prompt, and returns the list of losses; when making a video it also writes the current frame to ./steps/.

def ascend_txt():
    global i
    out = synth(z)
    iii = perceptor.encode_image(normalize(make_cutouts(out))).float()
    
    result = []

    if args.init_weight:
        # result.append(F.mse_loss(z, z_orig) * args.init_weight / 2)
        result.append(F.mse_loss(z, torch.zeros_like(z_orig)) * ((1/torch.tensor(i*2 + 1))*args.init_weight) / 2)

    for prompt in pMs:
        result.append(prompt(iii))
    
    if args.make_video:    
        img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]
        img = np.transpose(img, (1, 2, 0))
        imageio.imwrite('./steps/' + str(i) + '.png', np.array(img))

    return result # return loss

(28) The function train performs one training iteration: it clears the gradients, computes the losses, backpropagates, and lets the optimizer update the latent z. At the configured frequency it calls checkin to report progress, and afterwards it clamps z back into the allowed [z_min, z_max] range.

def train(i):
    opt.zero_grad(set_to_none=True)
    lossAll = ascend_txt()
    
    if i % args.display_freq == 0:
        checkin(i, lossAll)
       
    loss = sum(lossAll)
    loss.backward()
    opt.step()
    
    #with torch.no_grad():
    with torch.inference_mode():
        z.copy_(z.maximum(z_min).minimum(z_max))

(29) The code below drives frame generation for the video: it combines image processing with the training loop, adjusting the zoom and the text prompts while generating and saving frames. It first initializes a few counters and parameters, then repeatedly generates images, zooming and updating prompts at the configured frequencies and training to optimize each image. Finally, it assembles the saved frames into a video, optionally using hardware encoding and frame interpolation to improve quality.

i = 0  # iteration counter
j = 0  # zoom-video frame counter
p = 1  # phrase counter
smoother = 0  # smoother counter
this_video_frame = 0  # for video styling

# Adjusting the learning rate / optimiser
# variable_lr = args.step_size
# optimiser_list = [['Adam',0.075],['AdamW',0.125],['Adagrad',0.2],['Adamax',0.125],['DiffGrad',0.075],['RAdam',0.125],['RMSprop',0.02]]

# Start training
try:
    with tqdm() as pbar:
        while True:
            # Change the generated image
            if args.make_zoom_video:
                if i % args.zoom_frequency == 0:
                    out = synth(z)

                    # Save the image
                    img = np.array(out.mul(255).clamp(0, 255)[0].cpu().detach().numpy().astype(np.uint8))[:,:,:]
                    img = np.transpose(img, (1, 2, 0))
                    imageio.imwrite('./steps/' + str(j) + '.png', np.array(img))

                    # Time to start zooming?
                    if args.zoom_start <= i:
                        # Convert z back into a PIL image
                        # pil_image = TF.to_pil_image(out[0].cpu())

                        # Convert the NumPy array into a PIL image
                        pil_image = Image.fromarray(np.array(img).astype('uint8'), 'RGB')

                        # Zoom
                        if args.zoom_scale != 1:
                            pil_image_zoom = zoom_at(pil_image, sideX/2, sideY/2, args.zoom_scale)
                        else:
                            pil_image_zoom = pil_image

                        # Shift - see https://pillow.readthedocs.io/en/latest/reference/ImageChops.html
                        if args.zoom_shift_x or args.zoom_shift_y:
                            # This one wraps the image around
                            pil_image_zoom = ImageChops.offset(pil_image_zoom, args.zoom_shift_x, args.zoom_shift_y)

                        # Convert the image back into a tensor
                        pil_tensor = TF.to_tensor(pil_image_zoom)

                        # Re-encode
                        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
                        z_orig = z.clone()
                        z.requires_grad_(True)

                        # Re-create the optimiser
                        opt = get_opt(args.optimiser, args.step_size)

                    # Next frame
                    j += 1

            # Change the text prompt
            if args.prompt_frequency > 0:
                if i % args.prompt_frequency == 0 and i > 0:
                    # In case there aren't enough phrases, loop round
                    if p >= len(all_phrases):
                        p = 0

                    pMs = []
                    args.prompts = all_phrases[p]

                    # Show the user that the prompt is changing
                    print(args.prompts)

                    for prompt in args.prompts:
                        txt, weight, stop = split_prompt(prompt)
                        embed = perceptor.encode_text(clip.tokenize(txt).to(device)).float()
                        pMs.append(Prompt(embed, weight, stop).to(device))

                    p += 1

            # Training time
            train(i)

            # Ready to stop yet?
            if i == args.max_iterations:
                if not args.video_style_dir:
                    # Training complete
                    break
                else:
                    if this_video_frame == (num_video_frames - 1):
                        # Training complete
                        make_styled_video = True
                        break
                    else:
                        # Next video frame
                        this_video_frame += 1

                        # Reset the iteration count
                        i = -1
                        pbar.reset()

                        # Load the next frame, reset a few options - same filename, different directory
                        args.init_image = video_frame_list[this_video_frame]
                        print("Next frame: ", args.init_image)

                        if args.seed is None:
                            seed = torch.seed()
                        else:
                            seed = args.seed
                        torch.manual_seed(seed)
                        print("Seed: ", seed)

                        filename = os.path.basename(args.init_image)
                        args.output = os.path.join(cwd, "steps", filename)

                        # Load and resize the image
                        img = Image.open(args.init_image)
                        pil_image = img.convert('RGB')
                        pil_image = pil_image.resize((sideX, sideY), Image.LANCZOS)
                        pil_tensor = TF.to_tensor(pil_image)

                        # Re-encode
                        z, *_ = model.encode(pil_tensor.to(device).unsqueeze(0) * 2 - 1)
                        z_orig = z.clone()
                        z.requires_grad_(True)

                        # Re-create the optimiser
                        opt = get_opt(args.optimiser, args.step_size)

            i += 1
            pbar.update()
except KeyboardInterrupt:
    pass

# All done :)

# Video generation
if args.make_video or args.make_zoom_video:
    init_frame = 1  # Initial video frame
    if args.make_zoom_video:
        last_frame = j
    else:
        last_frame = i  # This will raise an error if that many frames do not exist

    length = args.video_length  # Desired length of the video in seconds

    min_fps = 10
    max_fps = 60

    total_frames = last_frame - init_frame

    frames = []
    tqdm.write('Generating video...')
    for i in range(init_frame, last_frame):
        temp = Image.open("./steps/" + str(i) + '.png')
        keep = temp.copy()
        frames.append(keep)
        temp.close()
    
    if args.output_video_fps > 9:
        # Hardware encoding and video frame interpolation
        print("Creating interpolated frames...")
        ffmpeg_filter = f"minterpolate='mi_mode=mci:me=hexbs:me_mode=bidir:mc_mode=aobmc:vsbmc=1:mb_size=8:search_param=32:fps={args.output_video_fps}'"
        output_file = re.compile(r'\.png$').sub('.mp4', args.output)
        try:
            p = Popen(['ffmpeg',
                       '-y',
                       '-f', 'image2pipe',
                       '-vcodec', 'png',
                       '-r', str(args.input_video_fps),               
                       '-i',
                       '-',
                       '-b:v', '10M',
                       '-vcodec', 'h264_nvenc',
                       '-pix_fmt', 'yuv420p',
                       '-strict', '-2',
                       '-filter:v', f'{ffmpeg_filter}',
                       '-metadata', f'comment={args.prompts}',
                       output_file], stdin=PIPE)
        except FileNotFoundError:
            print("The ffmpeg command failed - check your installation")
        for im in tqdm(frames):
            im.save(p.stdin, 'PNG')
        p.stdin.close()
        p.wait()
    else:
        # CPU
        fps = np.clip(total_frames / length, min_fps, max_fps)
        output_file = re.compile(r'\.png$').sub('.mp4', args.output)
        try:
            p = Popen(['ffmpeg',
                       '-y',
                       '-f', 'image2pipe',
                       '-vcodec', 'png',
                       '-r', str(fps),
                       '-i',
                       '-',
                       '-vcodec', 'libx264',
                       '-r', str(fps),
                       '-pix_fmt', 'yuv420p',
                       '-crf', '17',
                       '-preset', 'veryslow',
                       '-metadata', f'comment={args.prompts}',
                       output_file], stdin=PIPE)
        except FileNotFoundError:
            print("The ffmpeg command failed - check your installation")
        for im in tqdm(frames):
            im.save(p.stdin, 'PNG')
        p.stdin.close()
        p.wait() 
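As an end-to-end illustration of the zoom-video path, a command along these lines could be used (the prompt phrases and numeric values are arbitrary examples; '^' separates the story phrases parsed earlier):

python generate.py -p "a misty forest at dawn^the same forest at sunset" -cpe 200 -zvid -i 800 -zse 10 -zsc 0.99 -vl 15 -ofps 30

Frames are written to the steps directory, and because -ofps is above 9 the ffmpeg branch with h264_nvenc hardware encoding and minterpolate frame interpolation is used, so this variant assumes an Nvidia GPU and an ffmpeg build that supports both.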
