首个统一生成和判别任务的条件生成模型BiGR分享

BiGR模型是一种新型的条件图像生成模型,它通过使用紧凑的二进制潜在代码进行生成训练,以增强生成和表示能力。

作为首个在同一框架内统一生成和判别任务的条件生成模型,BiGR在保持高生成质量的同时,能有效地执行视觉生成、辨别和编辑等多种视觉任务。

BiGR建立在Llama主干之上,结合了掩码标记预测和二进制转码器。使用加权二进制交叉熵损失进行训练,以重建掩码标记。

BiGR可以灵活地用于各种视觉应用,例如以零样本方式进行修复、去除修复、编辑、插值和丰富,而无需针对特定任务进行结构更改或参数微调。

github项目地址:https://github.com/haoosz/BiGR。

一、环境安装

1、python环境

建议安装python版本在3.10以上。

2、pip库安装

pip install torch==2.2.0+cu118 torchvision==0.17.0+cu118 torchaudio==2.2.0 --extra-index-url https://download.pytorch.org/whl/cu118

pip install timm diffusers accelerate einops scipy opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple

3、模型下载

git lfs install

git clone https://huggingface.co/haoosz/BiGR

、功能测试

1、运行测试

(1)Image generation的python代码调用测试

import torch
from torchvision.utils import save_image
import argparse
import os
from time import time
from glob import glob
import numpy as np

from hparams import get_vqgan_hparams
from bae.binaryae import BinaryAutoEncoder, load_pretrain
from llama.load_bigr import load_bigr

def setup_device():
    return "cuda" if torch.cuda.is_available() else "cpu"

def configure_torch(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.set_grad_enabled(False)

def initialize_models(args, args_ae, device):
    model = load_bigr(args, args_ae, device)
    binaryae = BinaryAutoEncoder(args_ae).to(device)
    binaryae = load_pretrain(binaryae, args.ckpt_bae)
    
    print(f"The code length of B-AE is set to {args_ae.codebook_size}")
    print(f"B-AE checkpoint loaded from {args.ckpt_bae}")
    print(f"GPT Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"MLP Parameters in GPT: {sum(p.numel() for p in model.denoise_mlp.parameters()):,}")
    
    return model, binaryae

def generate_samples(model, bae, args, seed, image_size):
    configure_torch(seed)
    device = setup_device()
    
    model.eval()
    bae.eval()

    latent_size = image_size // 16
    class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
    y = torch.tensor(class_labels, device=device)
    
    start_time = time()
    samples = model.generate_with_cfg(
        cond=y,
        max_new_tokens=latent_size ** 2,
        cond_padding=args.cls_token_num,
        num_iter=args.num_sample_iter,
        out_dim=bae.codebook_size,
        cfg_scale=args.cfg_scale,
        cfg_schedule=args.cfg_schedule,
        gumbel_temp=args.gumbel_temp,
        gumbel_schedule=args.gumbel_schedule,
        sample_logits=True,
        proj_emb=None
    )
    end_time = time()
    print(f"Sample time: {end_time - start_time}")
    
    samples = samples.float().transpose(1, 2).reshape(y.size(0), -1, latent_size, latent_size)
    return bae.decode(samples)

def save_samples(samples, save_path):
    os.makedirs(save_path, exist_ok=True)
    sample_index = len(glob(f"{save_path}/*"))
    filename = os.path.join(save_path, f"{sample_index:03d}.png")
    save_image(samples, filename, nrow=4, normalize=True, value_range=(0, 1))
    print(f"Samples saved to {filename}")

def main(args):
    device = setup_device()
    model, binaryae = initialize_models(args, args_ae, device)
    samples = generate_samples(model, binaryae, args, seed=args.seed, image_size=args.image_size)
    save_samples(samples, args.save_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="BiGR-L")
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--ckpt", type=str, default=None, help="DiT checkpoint path.")
    parser.add_argument("--save-path", type=str, default="samples")
    parser.add_argument("--ckpt_bae", type=str, required=True, help='B-AE checkpoint path')
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dataset", type=str, required=True)

    parser.add_argument("--cls-token-num", type=int, default=1, help="Max token number for condition input")
    parser.add_argument("--dropout-p", type=float, default=0.1, help="Dropout probability for residual and FFN")
    parser.add_argument("--token-dropout-p", type=float, default=0.0, help="Token dropout probability")
    parser.add_argument("--drop-path-rate", type=float, default=0.0, help="Stochastic depth decay rate")
    parser.add_argument("--use_adaLN", action='store_true')
    parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 
    parser.add_argument("--no-compile", action='store_true')
    parser.add_argument("--p_flip", action='store_true', help='Predict z0 or z0 XOR zt (flipping)')
    parser.add_argument("--focal", type=float, default=-1, help='Focal coefficient')
    parser.add_argument("--alpha", type=float, default=-1, help='Alpha coefficient')
    parser.add_argument("--aux", type=float, default=0.0, help='VLB weight')
    parser.add_argument("--n_repeat", type=int, default=1, help='Repeat sample timesteps')
    parser.add_argument("--n_sample_steps", type=int, default=256, help="Sample time steps for diffusion training")
    parser.add_argument("--seq_len", type=int, default=256)
    
    parser.add_argument("--cfg-scale", type=float, default=4.0)
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
    parser.add_argument("--num_sample_iter", type=int, default=10)
    parser.add_argument("--gumbel_temp", type=float, default=0.)
    parser.add_argument("--cfg_schedule", type=str, default='constant', choices=['constant', 'linear'])
    parser.add_argument("--infer_steps", type=int, default=100, help="Inference time steps for diffusion")
    parser.add_argument("--gumbel_schedule", type=str, default='constant', choices=['constant', 'down', 'up'])
    
    args_ae = get_vqgan_hparams(parser)
    args = parser.parse_args()
    
    args_ae.img_size = args.image_size
    main(args)

未完......

更多详细的欢迎关注:杰哥新技术

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值