Layout Generation

# Import the required libraries and modules
import os
import os.path as op
import json
import pdb
import clip
import torch
import numpy as np
from tqdm import tqdm
import time
import random
import argparse
import openai
from transformers import GPT2TokenizerFast, LlamaForCausalLM, LlamaTokenizer
import transformers
from utils import *

# Configure the OpenAI API; used when calling models hosted by OpenAI
openai.organization = ""
openai.api_key = ""

# Map short model names to their API / Hugging Face model identifiers
llm_name2id = {
    'llama-7b': 'meta-llama/Llama-2-7b-hf',
    'llama-13b': 'meta-llama/Llama-2-13b-hf',
    'llama-70b': 'meta-llama/Llama-2-70b-hf',
    'gpt3.5': 'text-davinci-003',
    'gpt3.5-chat': 'gpt-3.5-turbo',
    'gpt4': 'gpt-4',
}

# Set up the command-line argument parser for user configuration options
parser = argparse.ArgumentParser(prog='LayoutGPT: text-based image layout planning', description='Use LayoutGPT to generate image layouts.')
parser.add_argument('--input_info_dir', type=str, default='./dataset/NSR-1K')
parser.add_argument('--base_output_dir', type=str, default='./llm_output')
parser.add_argument('--setting', type=str, default='counting', choices=['counting', 'spatial'])
parser.add_argument('--matching_content_type', type=str, default='visual')
parser.add_argument('--llm_type', type=str, default='gpt4', choices=list(llm_name2id.keys()))
parser.add_argument('--icl_type', type=str, default='k-similar', choices=['fixed-random', 'k-similar'])
parser.add_argument('--K', type=int, default=8)
parser.add_argument('--gpt_input_length_limit', type=int, default=3000)
parser.add_argument('--canvas_size', type=int, default=256)
parser.add_argument("--n_iter", type=int, default=1)
parser.add_argument("--test", action='store_true')
parser.add_argument('--verbose', default=False, action='store_true')
args = parser.parse_args()

# Load retrieval features and models depending on the in-context-learning setting
if args.icl_type == 'k-similar':
    # Initialize and load the CLIP model
    clip_feature_name = 'ViT-L/14'.lower().replace('/', '-')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, preprocess = clip.load('ViT-L/14', device=device)
    clip_model = clip_model.to(device)

# Load precomputed features used to measure similarity between the query and the exemplars
def load_features(content):
    """Load visual/text features from npz file"""
    np_filename = os.path.join(
        args.input_info_dir, args.setting,
        f'train.{args.setting}.{clip_feature_name}.{content}.npz',
    )
    feature_list = np.load(np_filename)['feature_list']
    features = torch.HalfTensor(feature_list).to(device)
    features /= features.norm(dim=-1, keepdim=True)
    return features
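
The .npz file is assumed to be precomputed offline. As a minimal sketch (not necessarily the authors' actual preprocessing), the text features for the training captions could be produced with the same CLIP model as follows; 'text' is used as the content type here, though the script defaults to 'visual':

# Hypothetical offline step: encode the training captions with CLIP and save
# them in the format load_features() expects ('feature_list' inside an .npz).
train_captions = [ex['prompt'] for ex in json.load(open('./dataset/NSR-1K/counting/counting.train.json'))]
with torch.no_grad():
    tokens = clip.tokenize(train_captions, truncate=True).to(device)
    feats = clip_model.encode_text(tokens)            # (N, 768) for ViT-L/14
    feats = feats / feats.norm(dim=-1, keepdim=True)  # unit-normalize each row
np.savez('./dataset/NSR-1K/counting/train.counting.vit-l-14.text.npz',
         feature_list=feats.cpu().numpy())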

# Build the exemplar prompt text for an image layout
def create_exemplar_prompt(caption, object_list, canvas_size, is_chat=False):
    if is_chat:
        prompt = ''
    else:
        prompt = f'\nPrompt: {caption}\nLayout:\n'
    for obj_info in object_list:
        category, bbox = obj_info
        coord_list = [int(i*canvas_size) for i in bbox]
        x, y, w, h = coord_list
        prompt += f'{category} {{height: {h}px; width: {w}px; top: {y}px; left: {x}px; }}\n'
    return prompt

# Build an appropriate prompt from the user input and retrieved exemplars for the chosen chat LLM
def form_prompt_for_chatgpt(text_input, top_k, tokenizer, supporting_examples=None, features=None):
    # Start with the system message; exemplars and the final query are appended below
    message_list = []
    system_prompt = 'Instruction: Given a sentence prompt that will be used to generate an image, plan the layout of the image. ' \
                'The generated layout should follow the CSS style, where each line starts with the object description ' \
                'and is followed by its absolute position. ' \
                'Formally, each line should be like "object {{width: ?px; height: ?px; left: ?px; top: ?px; }}". ' \
                'The image is {}px wide and {}px high. ' \
                'Therefore, all properties of the positions should not exceed {}px, ' \
                'including the addition of left and width and the addition of top and height. \n'.format(args.canvas_size, args.canvas_size, args.canvas_size)
    message_list.append({'role': 'system', 'content': system_prompt})
    final_prompt = f'Prompt: {text_input}\nLayout:'
    total_length = len(tokenizer(system_prompt + final_prompt)['input_ids'])
    
    if args.icl_type == 'k-similar':
        # Use CLIP to score the similarity between the query and all exemplars, then keep the top-k
        text_inputs = clip.tokenize(text_input, truncate=True).to(device)
        text_features = clip_model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * text_features @ features.T).softmax(dim=-1)
        _, indices = similarity[0].topk(top_k)
        supporting_examples = [supporting_examples[idx] for idx in indices]
    prompting_examples = ''

    # Iterate over the retrieved exemplars, stopping once the prompt would exceed the length limit
    for supporting_example in supporting_examples:
        # Choose the exemplar format that matches the setting
        if args.setting == 'counting':
            current_prompting_example = create_exemplar_prompt(
                caption=supporting_example['prompt'],
                object_list=supporting_example['object_list'],
                canvas_size=args.canvas_size,
            )
        else:
            current_prompting_example = create_exemplar_prompt(
                caption=supporting_example['prompt'],
                object_list=[supporting_example['obj1'], supporting_example['obj2']],
                canvas_size=args.canvas_size,
            )
        # Measure the token length of this exemplar
        cur_len = len(tokenizer(current_prompting_example)['input_ids'])
        # Stop adding exemplars once the total length would exceed the limit
        if total_length + cur_len > args.gpt_input_length_limit:
            break
        # Prepend, so that the most similar exemplar ends up last, right before the query
        prompting_examples = current_prompting_example + prompting_examples
        # Update the running token count
        total_length += cur_len

    # Merge all exemplars and the final user query into a single user message
    message_list.append({'role': 'user', 'content': prompting_examples + final_prompt})

    # Return the assembled list of chat messages
    return message_list
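
A hypothetical usage sketch of the function above, assuming the CLIP model, train_examples, and features have been loaded as in _main() below:

messages = form_prompt_for_chatgpt(
    text_input='three apples on a wooden table',  # hypothetical query
    top_k=args.K,
    tokenizer=GPT2TokenizerFast.from_pretrained('gpt2'),
    supporting_examples=train_examples,
    features=features,
)
# messages[0] is the system instruction; messages[1]['content'] contains the
# retrieved exemplars followed by 'Prompt: <query>\nLayout:'.
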
class StoppingCriteriaICL(transformers.StoppingCriteria):
    def __init__(self, stops=[],) -> None:
        super().__init__()
        # Move the stop token sequences to the GPU for fast comparison
        self.stops = [s.to('cuda') for s in stops]
    
    def __call__(self, input_ids, scores, **kwargs):
        # Return True once the generated sequence ends with any stop sequence
        for stop in self.stops:
            if torch.all(stop == input_ids[0][-len(stop):]):
                return True
        return False

This defines the function create_exemplar_prompt, which takes four parameters: caption (the descriptive text for the image), object_list (the list of objects in the image), canvas_size (the canvas size, which scales the layout), and is_chat (a boolean controlling whether the output format is meant for chat mode).

def create_exemplar_prompt(caption, object_list, canvas_size, is_chat=False):

This is a conditional check. If is_chat is True, prompt is initialized to an empty string. If it is False, prompt is initialized to a formatted string containing the caption followed by the 'Layout:' keyword, so a new layout description begins right after the prompt text.

    if is_chat:
        prompt = ''
    else:
        prompt = f'\nPrompt: {caption}\nLayout:\n'

This starts a loop over each object in object_list, binding each object's information to the variable obj_info.

Each obj_info is unpacked into two variables: category (the object's class) and bbox (the object's bounding-box coordinates). A list comprehension then converts the normalized bbox into pixel coordinates: each bbox value is multiplied by canvas_size and truncated to an integer, giving the concrete position on the canvas. coord_list is unpacked into four variables: x (left edge), y (top edge), w (width), and h (height). Finally, a formatted line describing the object's category, size, and position on the canvas is appended to prompt.

    for obj_info in object_list:
        category, bbox = obj_info
        coord_list = [int(i*canvas_size) for i in bbox]
        x, y, w, h = coord_list
        prompt += f'{category} {{height: {h}px; width: {w}px; top: {y}px; left: {x}px; }}\n'
    return prompt
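
A concrete example of the coordinate conversion: with canvas_size=256, the normalized box [0.25, 0.5, 0.25, 0.25] maps to x=64, y=128, w=64, h=64 (the caption and box here are illustrative):

print(create_exemplar_prompt(
    caption='an apple on a table',                     # hypothetical caption
    object_list=[('apple', [0.25, 0.5, 0.25, 0.25])],  # normalized (x, y, w, h)
    canvas_size=256,
))
# Prompt: an apple on a table
# Layout:
# apple {height: 64px; width: 64px; top: 128px; left: 64px; }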

This defines a function that builds an appropriate prompt for chat-style GPT models. Its parameters are: text_input (the user's text input), top_k (how many of the most relevant exemplars to select), tokenizer (used to measure prompt length in tokens), supporting_examples (the list of candidate exemplars), and features (the feature vectors used to compute similarity).

This class is a subclass of transformers.StoppingCriteria and defines a custom stopping condition for the text-generation loop.
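
For illustration, this mirrors how the class is wired up in _main() below; a sketch that assumes a CUDA device is available, since __init__ moves the stop sequences to the GPU:

# Stop generation once the model emits a blank line ('\n\n'), which is what
# separates exemplars in the prompt format used here.
tok = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')
stop_ids = [tok('\n\n', return_tensors='pt',
                add_special_tokens=False).input_ids.squeeze()[1:]]
stop_criteria = transformers.StoppingCriteriaList([StoppingCriteriaICL(stop_ids)])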

def llama_generation(prompt_for_llama, model, args, eos_token_id=2, stop_criteria=None):
    # Collect the generated responses
    responses = []
    # Generate one response per iteration, args.n_iter times
    for _ in range(args.n_iter):
        responses += model(
                    prompt_for_llama,
                    do_sample=True,
                    num_return_sequences=1,
                    eos_token_id=eos_token_id,
                    temperature=0.7,
                )
    # Extract the generated text and return it along with the raw responses
    response_text = [r['generated_text'] for r in responses]
    return response_text, responses

This function performs text generation with a LLaMA model via a transformers pipeline.
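
A hypothetical call, where pipe stands for the text-generation pipeline constructed in _main() below:

texts, raw = llama_generation('Prompt: two cats on a sofa\nLayout:', pipe, args,
                              eos_token_id=pipe.tokenizer.eos_token_id)
# texts holds args.n_iter sampled layout strings; since the pipeline is built
# with return_full_text=False, only the continuation after the prompt is kept.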

def gpt_generation(prompt_for_gpt, f_gpt_create, args, **kwargs):
    # Arguments for the model call
    input_kwargs = {
        "model": args.llm_id,
        "temperature": 0.7,
        "max_tokens": 256,
        "top_p": 1.0,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
        "stop": "Prompt:",
        "n": args.n_iter,
    }
    # Adjust the input format for the model type
    if args.llm_type == 'gpt3.5':
        input_kwargs["prompt"] = prompt_for_gpt
    else:
        input_kwargs["messages"] = prompt_for_gpt
    # Call the model to generate text
    response = f_gpt_create(**input_kwargs)

    # Parse and return the response according to the model type
    if args.llm_type == 'gpt3.5':
        response_text = [r["text"] for r in response.choices]
    else:
        response_text = [r["message"]["content"] for r in response.choices]

    return response_text, response

This function performs text generation with the OpenAI GPT models.
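
A hypothetical call of the chat variant (this requires a valid OpenAI key and the legacy openai<1.0 SDK this script targets, and args.llm_id must already be set):

messages = [{'role': 'system', 'content': 'Instruction: ...'},
            {'role': 'user', 'content': 'Prompt: two cats on a sofa\nLayout:'}]
texts, raw = gpt_generation(messages, openai.ChatCompletion.create, args)
# texts is a list of args.n_iter layout strings taken from response.choices.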

def _main(args):
    # Skip if results for this configuration already exist in the output directory
    args.output_dir = os.path.join(args.base_output_dir, args.setting)
    os.makedirs(args.output_dir, exist_ok=True)
    output_filename = os.path.join(args.output_dir, f'{args.llm_type}.{args.setting}.{args.icl_type}.k_{args.K}.px_{args.canvas_size}.json')
    if os.path.exists(output_filename):
        print(f'{output_filename} has already been processed.')
        return

    # Load the validation examples
    val_example_files = os.path.join(
        args.input_info_dir, args.setting,
        f'{args.setting}.val.json',
    )
    val_example_list = json.load(open(val_example_files))
    if args.test:
        val_example_list = val_example_list[:3]

    # Load the training examples
    train_example_files = os.path.join(
        args.input_info_dir, args.setting,
        f'{args.setting}.train.json',
    )
    train_examples = json.load(open(train_example_files))
    if args.icl_type == 'fixed-random':
        random.seed(42)
        random.shuffle(train_examples)
        supporting_examples = train_examples[:args.K]
        features = None
    elif args.icl_type == 'k-similar':
        supporting_examples = train_examples
        features = load_features(args.matching_content_type)

    # Select the prompt-construction and generation functions for the model type
    args.llm_id = llm_name2id[args.llm_type]
    all_prediction_list = []
    all_responses = []

    if 'llama' in args.llm_type:
        f_form_prompt = form_prompt_for_gpt3

        tokenizer = LlamaTokenizer.from_pretrained(args.llm_id)
        stop_ids = [tokenizer(w, return_tensors="pt", add_special_tokens=False).input_ids.squeeze()[1:] for w in ["\n\n"]] # skip the leading token to handle Llama tokenization quirks
        stop_criteria = transformers.StoppingCriteriaList([StoppingCriteriaICL(stop_ids)])

        model = transformers.pipeline(
            "text-generation",
            model=args.llm_id,
            torch_dtype=torch.float16,
            device_map="auto",
            max_new_tokens=512,
            return_full_text=False,
            stopping_criteria=stop_criteria
        )
        f_llm_generation = llama_generation
    elif 'gpt' in args.llm_type:
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

        if args.llm_type == 'gpt3.5':
            f_form_prompt = form_prompt_for_gpt3
            model = openai.Completion.create
        else:
            f_form_prompt = form_prompt_for_chatgpt
            model = openai.ChatCompletion.create

        f_llm_generation = gpt_generation
    else:
        raise NotImplementedError

    # Process each validation example: build the prompt, generate, and parse the output
    for val_example in tqdm(val_example_list, total=len(val_example_list), desc='test'):
        top_k = args.K
        prompt_for_gpt = f_form_prompt(
            text_input=val_example['prompt'],
            top_k=top_k,
            tokenizer=tokenizer,
            supporting_examples=supporting_examples,
            features=features
        )
        if args.verbose:
            print(prompt_for_gpt)
            print('\n' + '-'*30)

        while True:
            try:
                response, raw_response = f_llm_generation(prompt_for_gpt, model, args, eos_token_id=tokenizer.eos_token_id)
                break
            except openai.error.ServiceUnavailableError:
                print('OpenAI ServiceUnavailableError.\tWill try again in 5 seconds.')
                time.sleep(5)
            except openai.error.RateLimitError:
                print('OpenAI RateLimitError.\tWill try again in 5 seconds.')
                time.sleep(5)
            except openai.error.InvalidRequestError as e:
                print(e)
                print('Input too long. Will shrink the prompting examples.')
                top_k -= 1
                prompt_for_gpt = f_form_prompt(
                    text_input=val_example['prompt'],
                    top_k=top_k,
                    tokenizer=tokenizer,
                    supporting_examples=supporting_examples,
                    features=features
                )
            except RuntimeError as e:
                if "out of memory" in str(e):
                    top_k -= 1
                    prompt_for_gpt = f_form_prompt(
                        text_input=val_example['prompt'],
                        top_k=top_k,
                        tokenizer=tokenizer,
                        supporting_examples=supporting_examples,
                        features=features
                    )
                else:
                    raise e

        all_responses.append(response)
        for i_iter in range(args.n_iter):
            # Parse the output
            predicted_object_list = []
            line_list = response[i_iter].split('\n')
                
            for line in line_list:
                if line == '':
                    continue
                try:
                    selector_text, bbox = parse_layout(line, canvas_size=args.canvas_size)
                    if selector_text is None:
                        print(line)
                        continue
                    predicted_object_list.append([selector_text, bbox])
                except ValueError:
                    pass
            all_prediction_list.append({
                'query_id': val_example['id'],
                'iter': i_iter,
                'prompt': val_example['prompt'],
                'object_list': predicted_object_list,
            })

    # Save the predictions
    with open(output_filename, 'w') as fout:
        json.dump(all_prediction_list, fout, indent=4, sort_keys=True)
    print(f'LayoutGPT ({args.llm_type}) prediction results written to {output_filename}')
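
parse_layout comes from utils and is not shown in this excerpt. A minimal regex-based sketch of what such a parser could look like, assuming lines follow the exact CSS-style format emitted by create_exemplar_prompt (the real implementation may differ):

import re

def parse_layout_sketch(line, canvas_size=256):
    """Hypothetical stand-in for utils.parse_layout: extract the object name
    and a normalized [x, y, w, h] box from a CSS-style layout line."""
    m = re.match(r'(.+?)\s*\{(.+?)\}', line)
    if m is None:
        return None, None
    selector_text = m.group(1).strip()
    props = dict(re.findall(r'(\w+):\s*(-?\d+)px', m.group(2)))
    if not {'height', 'width', 'top', 'left'} <= props.keys():
        raise ValueError(f'incomplete layout line: {line}')
    bbox = [int(props['left']) / canvas_size, int(props['top']) / canvas_size,
            int(props['width']) / canvas_size, int(props['height']) / canvas_size]
    return selector_text, bbox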

Initialize an empty list, message_list, which stores the message parts that make up the final prompt.

    message_list = []

Construct a system prompt that spells out the instructions for generating an image layout, including the CSS-style format and the size constraints; args.canvas_size supplies the image dimensions.

    system_prompt = 'Instruction: Given a sentence prompt that will be used to generate an image, plan the layout of the image. ' \
                    'The generated layout should follow the CSS style, where each line starts with the object description ' \
                    'and is followed by its absolute position. ' \
                    'Formally, each line should be like "object {{width: ?px; height: ?px; left: ?px; top: ?px; }}". ' \
                    'The image is {}px wide and {}px high. ' \
                    'Therefore, all properties of the positions should not exceed {}px, ' \
                    'including the addition of left and width and the addition of top and height. \n'.format(args.canvas_size, args.canvas_size, args.canvas_size)

Append the system prompt to message_list as a dictionary whose role is 'system' and whose content is the system_prompt created above.

    message_list.append({'role': 'system', 'content': system_prompt})

Create a prompt string containing the user input followed by the 'Layout:' keyword that opens the layout.

    final_prompt = f'Prompt: {text_input}\nLayout:'

Compute the combined token length of the system prompt and the final user prompt; this is used to ensure the assembled prompt does not exceed the model's input length limit.

    total_length = len(tokenizer(system_prompt + final_prompt)['input_ids'])
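
For intuition, length is measured in tokens with the GPT-2 tokenizer, so the budget check in the exemplar loop compares token counts rather than characters:

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
n_tokens = len(tokenizer('Prompt: two cats on a sofa\nLayout:')['input_ids'])
# n_tokens is on the order of ten tokens; exemplars keep being added until
# total_length would exceed args.gpt_input_length_limit (3000 by default).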
