BiGR is a new type of conditional image generation model that performs generative training on compact binary latent codes to strengthen both its generative and representational capabilities.
As the first conditional generative model to unify generative and discriminative tasks within a single framework, BiGR maintains high generation quality while effectively handling a range of visual tasks, including generation, discrimination, and editing.
BiGR is built on a Llama backbone and combines masked token prediction with a binary transcoder; it is trained with a weighted binary cross-entropy loss to reconstruct the masked tokens.
BiGR can be applied flexibly to a variety of visual applications in a zero-shot manner, such as inpainting, outpainting, editing, interpolation, and enrichment, without task-specific architectural changes or parameter fine-tuning.
GitHub repository: https://github.com/haoosz/BiGR
I. Environment Setup
1. Python environment
Python 3.10 or later is recommended.
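As a quick sanity check (a convenience sketch, not part of the BiGR project itself), the following snippet verifies that the active interpreter meets the recommended version:
import sys
# Abort early if the interpreter is older than the recommended 3.10
assert sys.version_info >= (3, 10), f"Python 3.10+ recommended, found {sys.version.split()[0]}"
print(sys.version)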
2. Installing dependencies with pip
pip install torch==2.2.0+cu118 torchvision==0.17.0+cu118 torchaudio==2.2.0 --extra-index-url https://download.pytorch.org/whl/cu118
pip install timm diffusers accelerate einops scipy opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple
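After the packages are installed, a short check (a sketch assuming the commands above succeeded) confirms that the CUDA build of PyTorch can actually see the GPU:
import torch
import torchvision
# Versions should match the pinned cu118 builds installed above
print("torch:", torch.__version__, "| torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))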
3. Downloading the model:
git lfs install
git clone https://huggingface.co/haoosz/BiGR
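If git-lfs is not convenient, the same checkpoints can usually be pulled with the huggingface_hub client instead; the snippet below is an alternative sketch (the local_dir path is just an example, and huggingface_hub must be installed separately):
from huggingface_hub import snapshot_download
# Download the haoosz/BiGR weight repository into a local folder
snapshot_download(repo_id="haoosz/BiGR", local_dir="./BiGR")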
II. Functional Testing
1. Running the tests:
(1) Calling the image generation code from Python
import torch
from torchvision.utils import save_image
import argparse
import os
from time import time
from glob import glob
import numpy as np
from hparams import get_vqgan_hparams
from bae.binaryae import BinaryAutoEncoder, load_pretrain
from llama.load_bigr import load_bigr
def setup_device():
    # Use the GPU when available, otherwise fall back to CPU
    return "cuda" if torch.cuda.is_available() else "cpu"

def configure_torch(seed):
    # Fix random seeds and disable autograd for inference
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.set_grad_enabled(False)

def initialize_models(args, args_ae, device):
    # Load the BiGR transformer and the binary autoencoder (B-AE)
    model = load_bigr(args, args_ae, device)
    binaryae = BinaryAutoEncoder(args_ae).to(device)
    binaryae = load_pretrain(binaryae, args.ckpt_bae)
    print(f"The code length of B-AE is set to {args_ae.codebook_size}")
    print(f"B-AE checkpoint loaded from {args.ckpt_bae}")
    print(f"GPT Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"MLP Parameters in GPT: {sum(p.numel() for p in model.denoise_mlp.parameters()):,}")
    return model, binaryae
def generate_samples(model, bae, args, seed, image_size):
    configure_torch(seed)
    device = setup_device()
    model.eval()
    bae.eval()
    latent_size = image_size // 16
    # ImageNet class labels used as the generation condition
    class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
    y = torch.tensor(class_labels, device=device)
    start_time = time()
    samples = model.generate_with_cfg(
        cond=y,
        max_new_tokens=latent_size ** 2,
        cond_padding=args.cls_token_num,
        num_iter=args.num_sample_iter,
        out_dim=bae.codebook_size,
        cfg_scale=args.cfg_scale,
        cfg_schedule=args.cfg_schedule,
        gumbel_temp=args.gumbel_temp,
        gumbel_schedule=args.gumbel_schedule,
        sample_logits=True,
        proj_emb=None
    )
    end_time = time()
    print(f"Sample time: {end_time - start_time}")
    # Fold the token sequence into a (B, C, H, W) latent grid and decode it to images
    samples = samples.float().transpose(1, 2).reshape(y.size(0), -1, latent_size, latent_size)
    return bae.decode(samples)
def save_samples(samples, save_path):
    os.makedirs(save_path, exist_ok=True)
    # Name the output after the number of files already in the folder
    sample_index = len(glob(f"{save_path}/*"))
    filename = os.path.join(save_path, f"{sample_index:03d}.png")
    save_image(samples, filename, nrow=4, normalize=True, value_range=(0, 1))
    print(f"Samples saved to {filename}")
def main(args):
    device = setup_device()
    # args_ae is defined at module scope in the __main__ block below
    model, binaryae = initialize_models(args, args_ae, device)
    samples = generate_samples(model, binaryae, args, seed=args.seed, image_size=args.image_size)
    save_samples(samples, args.save_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="BiGR-L")
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--ckpt", type=str, default=None, help="DiT checkpoint path.")
    parser.add_argument("--save-path", type=str, default="samples")
    parser.add_argument("--ckpt_bae", type=str, required=True, help='B-AE checkpoint path')
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--cls-token-num", type=int, default=1, help="Max token number for condition input")
    parser.add_argument("--dropout-p", type=float, default=0.1, help="Dropout probability for residual and FFN")
    parser.add_argument("--token-dropout-p", type=float, default=0.0, help="Token dropout probability")
    parser.add_argument("--drop-path-rate", type=float, default=0.0, help="Stochastic depth decay rate")
    parser.add_argument("--use_adaLN", action='store_true')
    parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--no-compile", action='store_true')
    parser.add_argument("--p_flip", action='store_true', help='Predict z0 or z0 XOR zt (flipping)')
    parser.add_argument("--focal", type=float, default=-1, help='Focal coefficient')
    parser.add_argument("--alpha", type=float, default=-1, help='Alpha coefficient')
    parser.add_argument("--aux", type=float, default=0.0, help='VLB weight')
    parser.add_argument("--n_repeat", type=int, default=1, help='Repeat sample timesteps')
    parser.add_argument("--n_sample_steps", type=int, default=256, help="Sample time steps for diffusion training")
    parser.add_argument("--seq_len", type=int, default=256)
    parser.add_argument("--cfg-scale", type=float, default=4.0)
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
    parser.add_argument("--num_sample_iter", type=int, default=10)
    parser.add_argument("--gumbel_temp", type=float, default=0.)
    parser.add_argument("--cfg_schedule", type=str, default='constant', choices=['constant', 'linear'])
    parser.add_argument("--infer_steps", type=int, default=100, help="Inference time steps for diffusion")
    parser.add_argument("--gumbel_schedule", type=str, default='constant', choices=['constant', 'down', 'up'])
    args_ae = get_vqgan_hparams(parser)
    args = parser.parse_args()
    args_ae.img_size = args.image_size
    main(args)
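To make the post-processing step in generate_samples concrete, the toy example below (hypothetical shapes only, not tied to a real checkpoint) reshapes a batch of generated token sequences of length latent_size ** 2 with code dimension C into the (B, C, H, W) latent grid that the binary autoencoder then decodes:
import torch

B, latent_size, C = 8, 16, 24                # hypothetical batch size, 16x16 latent grid, 24-bit codes
tokens = torch.rand(B, latent_size ** 2, C)  # assumed output layout: (batch, num_tokens, code_dim)

# Same post-processing as in generate_samples: move the code dimension to the
# channel position, then fold the token axis back into a 2D latent grid
latents = tokens.float().transpose(1, 2).reshape(B, -1, latent_size, latent_size)
print(latents.shape)  # torch.Size([8, 24, 16, 16])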
To be continued...
For more details, follow: 杰哥新技术