# 导入所需的库和模块
import os
import os.path as op
import json
import pdb
import clip
import torch
import numpy as np
from tqdm import tqdm
import time
import random
import argparse
import openai
from transformers import GPT2TokenizerFast, LlamaForCausalLM, LlamaTokenizer
import transformers
from utils import *
# OpenAI API credentials used by the gpt* backends.
# NOTE(review): both fields are blank — fill them in (or load from the
# environment) before selecting a gpt* model, or every API call will fail.
openai.organization = ""
openai.api_key = ""
# Map short CLI model names to HuggingFace model ids / OpenAI API model names.
llm_name2id = {
    'llama-7b': 'meta-llama/Llama-2-7b-hf',
    'llama-13b': 'meta-llama/Llama-2-13b-hf',
    'llama-70b': 'meta-llama/Llama-2-70b-hf',
    'gpt3.5': 'text-davinci-003',
    'gpt3.5-chat': 'gpt-3.5-turbo',
    'gpt4': 'gpt-4',
}
# Command-line interface for LayoutGPT text-based image layout planning.
parser = argparse.ArgumentParser(prog='LayoutGPT: text-based image layout planning', description='Use LayoutGPT to generate image layouts.')
parser.add_argument('--input_info_dir', type=str, default='./dataset/NSR-1K')  # root of the NSR-1K benchmark data
parser.add_argument('--base_output_dir', type=str, default='./llm_output')  # where prediction JSON files are written
parser.add_argument('--setting', type=str, default='counting', choices=['counting', 'spatial'])  # benchmark subtask
parser.add_argument('--matching_content_type', type=str, default='visual')  # which precomputed CLIP feature file to load
parser.add_argument('--llm_type', type=str, default='gpt4', choices=list(llm_name2id.keys()))  # backend LLM
parser.add_argument('--icl_type', type=str, default='k-similar', choices=['fixed-random', 'k-similar'])  # exemplar selection strategy
parser.add_argument('--K', type=int, default=8)  # number of in-context exemplars
parser.add_argument('--gpt_input_length_limit', type=int, default=3000)  # token budget for the assembled prompt
parser.add_argument('--canvas_size', type=int, default=256)  # layout canvas side length in pixels
parser.add_argument("--n_iter", type=int, default=1)  # generations sampled per query
parser.add_argument("--test", action='store_true')  # quick run on only 3 validation examples
parser.add_argument('--verbose', default=False, action='store_true')  # print each prompt before generation
args = parser.parse_args()
# Only the k-similar ICL mode needs CLIP: it ranks training exemplars by
# similarity to the query prompt using precomputed CLIP features.
if args.icl_type == 'k-similar':
    # e.g. 'ViT-L/14' -> 'vit-l-14'; used in the precomputed feature filenames.
    clip_feature_name = 'ViT-L/14'.lower().replace('/', '-')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, preprocess = clip.load('ViT-L/14', device=device)
    clip_model = clip_model.to(device)
# Load the precomputed CLIP features used for exemplar retrieval.
def load_features(content):
    """Return L2-normalized CLIP train-set features from the matching .npz file.

    `content` selects which feature file to read (e.g. 'visual' or 'text');
    the path is derived from the global `args` and `clip_feature_name`.
    """
    npz_path = os.path.join(
        args.input_info_dir, args.setting,
        f'train.{args.setting}.{clip_feature_name}.{content}.npz',
    )
    raw_features = np.load(npz_path)['feature_list']
    feats = torch.HalfTensor(raw_features).to(device)
    # Normalize rows so a dot product computes cosine similarity.
    return feats / feats.norm(dim=-1, keepdim=True)
# Render one training exemplar as CSS-style layout text.
def create_exemplar_prompt(caption, object_list, canvas_size, is_chat=False):
    """Format an exemplar (caption + objects) as CSS-style layout lines.

    Each bbox is given in normalized [0, 1] coordinates (x, y, w, h) and is
    scaled to integer pixels on a `canvas_size` x `canvas_size` canvas. When
    `is_chat` is True the "Prompt:/Layout:" header is omitted.
    """
    header = '' if is_chat else f'\nPrompt: {caption}\nLayout:\n'
    rendered = []
    for category, bbox in object_list:
        x, y, w, h = (int(coord * canvas_size) for coord in bbox)
        rendered.append(f'{category} {{height: {h}px; width: {w}px; top: {y}px; left: {x}px; }}\n')
    return header + ''.join(rendered)
# Build the chat-style prompt (message list) for OpenAI ChatCompletion models.
def form_prompt_for_chatgpt(text_input, top_k, tokenizer, supporting_examples=None, features=None):
    """Assemble the ChatCompletion message list for one layout query.

    Args:
        text_input: the user's image prompt.
        top_k: number of most-similar exemplars to retrieve (k-similar mode).
        tokenizer: tokenizer used only to count tokens against the budget.
        supporting_examples: exemplar pool (k-similar) or fixed exemplars.
        features: precomputed CLIP features for similarity ranking, or None.

    Returns:
        A list of {'role', 'content'} dicts: a system instruction followed by
        one user message containing the in-context exemplars and the query.

    BUG FIX: the original body referenced `prompting_examples`, `last_example`
    and `rtn_prompt` without ever defining them (NameError at runtime) and
    returned a plain string, while gpt_generation passes this value as
    `messages=` to openai.ChatCompletion.create, which requires a message list.
    """
    message_list = []
    # NOTE(review): there is a missing space between the first two sentences
    # ("...image.The generated...") — kept byte-identical for reproducibility.
    system_prompt = 'Instruction: Given a sentence prompt that will be used to generate an image, plan the layout of the image.' \
        'The generated layout should follow the CSS style, where each line starts with the object description ' \
        'and is followed by its absolute position. ' \
        'Formally, each line should be like "object {{width: ?px; height: ?px; left: ?px; top: ?px; }}". ' \
        'The image is {}px wide and {}px high. ' \
        'Therefore, all properties of the positions should not exceed {}px, ' \
        'including the addition of left and width and the addition of top and height. \n'.format(args.canvas_size, args.canvas_size, args.canvas_size)
    message_list.append({'role': 'system', 'content': system_prompt})
    final_prompt = f'Prompt: {text_input}\nLayout:'
    # Running token count, used to respect args.gpt_input_length_limit.
    total_length = len(tokenizer(system_prompt + final_prompt)['input_ids'])

    if args.icl_type == 'k-similar':
        # Rank the exemplar pool by CLIP text similarity and keep the top k.
        text_inputs = clip.tokenize(text_input, truncate=True).to(device)
        text_features = clip_model.encode_text(text_inputs)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        similarity = (100.0 * text_features @ features.T).softmax(dim=-1)
        _, indices = similarity[0].topk(top_k)
        supporting_examples = [supporting_examples[idx] for idx in indices]

    # Accumulate exemplars (most similar first in the iteration) until the
    # token budget would be exceeded.
    prompting_examples = ''
    for supporting_example in supporting_examples:
        # The 'counting' setting stores a full object list; 'spatial' stores
        # exactly two objects as obj1/obj2.
        if args.setting == 'counting':
            current_prompting_example = create_exemplar_prompt(
                caption=supporting_example['prompt'],
                object_list=supporting_example['object_list'],
                canvas_size=args.canvas_size,
            )
        else:
            current_prompting_example = create_exemplar_prompt(
                caption=supporting_example['prompt'],
                object_list=[supporting_example['obj1'], supporting_example['obj2']],
                canvas_size=args.canvas_size,
            )
        cur_len = len(tokenizer(current_prompting_example)['input_ids'])
        # Stop once adding this exemplar would exceed the input length limit.
        if total_length + cur_len > args.gpt_input_length_limit:
            break
        # Prepend, so the most similar exemplar ends up last — closest to the query.
        prompting_examples = current_prompting_example + prompting_examples
        total_length += cur_len

    # Exemplars plus the actual query go into a single user message.
    message_list.append({'role': 'user', 'content': prompting_examples + final_prompt})
    return message_list
class StoppingCriteriaICL(transformers.StoppingCriteria):
    """Stop generation when any given stop-token sequence appears at the end
    of the generated ids (used to cut LLaMA output at a blank line)."""

    def __init__(self, stops=[],) -> None:
        super().__init__()
        # BUG FIX: the stop tensors were eagerly moved with .to('cuda'), which
        # crashes on CPU-only machines; keep them as-is and move them to the
        # input's device lazily at call time instead.
        self.stops = list(stops)

    def __call__(self, input_ids, scores, **kwargs):
        # True iff the tail of the first sequence matches any stop sequence.
        for stop in self.stops:
            stop = stop.to(input_ids.device)
            if torch.all(stop == input_ids[0][-len(stop):]):
                return True
        return False
# NOTE(review): the section below was a duplicate of create_exemplar_prompt
# (defined earlier in this file) pasted together with tutorial prose, which
# made the module unparseable. The prose has been converted into the comments
# and docstring here, and the duplicate is restored as a valid, identical
# redefinition so the file stays importable.
def create_exemplar_prompt(caption, object_list, canvas_size, is_chat=False):
    """Format one exemplar as CSS-style layout text.

    Args:
        caption: descriptive text of the image.
        object_list: iterable of (category, bbox) pairs; bbox is (x, y, w, h)
            in normalized [0, 1] coordinates.
        canvas_size: canvas side length in pixels used to scale the bboxes.
        is_chat: if True, start from an empty string (chat-style exemplar);
            otherwise open with a "Prompt: .../Layout:" header.
    """
    if is_chat:
        prompt = ''
    else:
        prompt = f'\nPrompt: {caption}\nLayout:\n'
    # For each object: unpack its category and bbox, scale the normalized
    # bbox to integer pixel coordinates, and append one CSS-style line
    # describing its size and position on the canvas.
    for obj_info in object_list:
        category, bbox = obj_info
        coord_list = [int(i*canvas_size) for i in bbox]
        x, y, w, h = coord_list
        prompt += f'{category} {{height: {h}px; width: {w}px; top: {y}px; left: {x}px; }}\n'
    return prompt
# NOTE(review): converted stray tutorial prose into a comment so the module parses.
# form_prompt_for_chatgpt (defined above) builds the prompt for chat-style GPT
# models from: text_input (the user's text), top_k (number of most relevant
# exemplars to select), tokenizer (used to count tokens), supporting_examples
# (the exemplar pool) and features (CLIP feature vectors used for similarity).
# StoppingCriteriaICL (above) subclasses transformers.StoppingCriteria to
# implement a custom stop condition for text generation.
def llama_generation(prompt_for_llama, model, args, eos_token_id=2, stop_criteria=None):
    """Sample ``args.n_iter`` generations from a LLaMA text-generation pipeline.

    Args:
        prompt_for_llama: the full prompt string.
        model: a transformers text-generation pipeline (callable).
        args: namespace providing ``n_iter``.
        eos_token_id: end-of-sequence token id passed to the pipeline.
        stop_criteria: unused here; stopping criteria are baked into the pipeline.

    Returns:
        (list of generated texts, list of raw pipeline outputs).
    """
    raw_outputs = []
    # One pipeline call per requested generation.
    for _ in range(args.n_iter):
        raw_outputs.extend(model(
            prompt_for_llama,
            do_sample=True,
            num_return_sequences=1,
            eos_token_id=eos_token_id,
            temperature=0.7,
        ))
    texts = [item['generated_text'] for item in raw_outputs]
    return texts, raw_outputs
# llama_generation (above) performs text generation with a LLaMA pipeline.
def gpt_generation(prompt_for_gpt, f_gpt_create, args, **kwargs):
    """Call an OpenAI completion/chat endpoint and collect the generated texts.

    Args:
        prompt_for_gpt: a prompt string (legacy completion) or message list (chat).
        f_gpt_create: openai.Completion.create or openai.ChatCompletion.create.
        args: namespace providing ``llm_id``, ``llm_type`` and ``n_iter``.
        **kwargs: ignored extras (e.g. eos_token_id from the shared call site).

    Returns:
        (list of generated texts, raw API response).
    """
    request = {
        "model": args.llm_id,
        "temperature": 0.7,
        "max_tokens": 256,
        "top_p": 1.0,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
        "stop": "Prompt:",
        "n": args.n_iter,
    }
    # The legacy completion API takes a raw prompt; chat APIs take messages.
    is_completion = args.llm_type == 'gpt3.5'
    if is_completion:
        request["prompt"] = prompt_for_gpt
    else:
        request["messages"] = prompt_for_gpt
    response = f_gpt_create(**request)
    # The two APIs nest the generated text differently in each choice.
    if is_completion:
        response_text = [choice["text"] for choice in response.choices]
    else:
        response_text = [choice["message"]["content"] for choice in response.choices]
    return response_text, response
# gpt_generation (above) performs text generation with an OpenAI GPT model.
def _main(args):
    """Run LayoutGPT layout generation over the validation set and save results.

    Loads train/val examples, builds an in-context prompt per validation
    example, queries the selected LLM (retrying on transient OpenAI errors
    and shrinking the exemplar count on over-long inputs / GPU OOM), parses
    the generated CSS-style layouts, and writes predictions to a JSON file.

    BUG FIX: the InvalidRequestError retry re-formed the prompt without the
    required ``tokenizer=`` argument, which raised TypeError instead of
    shrinking the prompt.
    """
    # Skip work if this configuration's output file already exists.
    args.output_dir = os.path.join(args.base_output_dir, args.setting)
    os.makedirs(args.output_dir, exist_ok=True)
    output_filename = os.path.join(args.output_dir, f'{args.llm_type}.{args.setting}.{args.icl_type}.k_{args.K}.px_{args.canvas_size}.json')
    if os.path.exists(output_filename):
        print(f'{output_filename} has been processed.')
        return

    # Load validation examples (queries).
    val_example_files = os.path.join(
        args.input_info_dir, args.setting,
        f'{args.setting}.val.json',
    )
    with open(val_example_files) as f:
        val_example_list = json.load(f)
    if args.test:
        # Quick smoke-test mode: only the first three queries.
        val_example_list = val_example_list[:3]

    # Load training examples (the in-context exemplar pool).
    train_example_files = os.path.join(
        args.input_info_dir, args.setting,
        f'{args.setting}.train.json',
    )
    with open(train_example_files) as f:
        train_examples = json.load(f)

    if args.icl_type == 'fixed-random':
        # One fixed random subset of K exemplars, reproducible via the seed.
        random.seed(42)
        random.shuffle(train_examples)
        supporting_examples = train_examples[:args.K]
        features = None
    elif args.icl_type == 'k-similar':
        # Keep the whole pool; the top-K similar exemplars are picked per query.
        supporting_examples = train_examples
        features = load_features(args.matching_content_type)

    # Choose the prompt builder and generation function for the selected LLM.
    args.llm_id = llm_name2id[args.llm_type]
    all_prediction_list = []
    all_responses = []
    if 'llama' in args.llm_type:
        f_form_prompt = form_prompt_for_gpt3
        tokenizer = LlamaTokenizer.from_pretrained(args.llm_id)
        # Drop the leading artifact token produced when tokenizing "\n\n".
        stop_ids = [tokenizer(w, return_tensors="pt", add_special_tokens=False).input_ids.squeeze()[1:] for w in ["\n\n"]]
        stop_criteria = transformers.StoppingCriteriaList([StoppingCriteriaICL(stop_ids)])
        model = transformers.pipeline(
            "text-generation",
            model=args.llm_id,
            torch_dtype=torch.float16,
            device_map="auto",
            max_new_tokens=512,
            return_full_text=False,
            stopping_criteria=stop_criteria
        )
        f_llm_generation = llama_generation
    elif 'gpt' in args.llm_type:
        # GPT-2 tokenizer is only used to estimate prompt token counts.
        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        if args.llm_type == 'gpt3.5':
            f_form_prompt = form_prompt_for_gpt3
            model = openai.Completion.create
        else:
            f_form_prompt = form_prompt_for_chatgpt
            model = openai.ChatCompletion.create
        f_llm_generation = gpt_generation
    else:
        raise NotImplementedError

    # Generate and parse a layout for every validation example.
    for val_example in tqdm(val_example_list, total=len(val_example_list), desc='test'):
        top_k = args.K
        prompt_for_gpt = f_form_prompt(
            text_input=val_example['prompt'],
            top_k=top_k,
            tokenizer=tokenizer,
            supporting_examples=supporting_examples,
            features=features
        )
        if args.verbose:
            print(prompt_for_gpt)
            print('\n' + '-'*30)
        while True:
            try:
                response, raw_response = f_llm_generation(prompt_for_gpt, model, args, eos_token_id=tokenizer.eos_token_id)
                break
            except openai.error.ServiceUnavailableError:
                print('OpenAI ServiceUnavailableError.\tWill try again in 5 seconds.')
                time.sleep(5)
            except openai.error.RateLimitError:
                print('OpenAI RateLimitError.\tWill try again in 5 seconds.')
                time.sleep(5)
            except openai.error.InvalidRequestError as e:
                # Input too long: retry with one fewer in-context exemplar.
                print(e)
                print('Input too long. Will shrink the prompting examples.')
                top_k -= 1
                prompt_for_gpt = f_form_prompt(
                    text_input=val_example['prompt'],
                    top_k=top_k,
                    tokenizer=tokenizer,  # BUG FIX: was missing, raising TypeError
                    supporting_examples=supporting_examples,
                    features=features
                )
            except RuntimeError as e:
                if "out of memory" in str(e):
                    # GPU OOM: retry with one fewer in-context exemplar.
                    top_k -= 1
                    prompt_for_gpt = f_form_prompt(
                        text_input=val_example['prompt'],
                        top_k=top_k,
                        tokenizer=tokenizer,
                        supporting_examples=supporting_examples,
                        features=features
                    )
                else:
                    raise e
        all_responses.append(response)
        for i_iter in range(args.n_iter):
            # Parse each generated line into (category, bbox); skip blanks and
            # malformed lines (best-effort parsing).
            predicted_object_list = []
            line_list = response[i_iter].split('\n')
            for line in line_list:
                if line == '':
                    continue
                try:
                    selector_text, bbox = parse_layout(line, canvas_size=args.canvas_size)
                    if selector_text is None:
                        print(line)
                        continue
                    predicted_object_list.append([selector_text, bbox])
                except ValueError:
                    pass
            all_prediction_list.append({
                'query_id': val_example['id'],
                'iter': i_iter,
                'prompt': val_example['prompt'],
                'object_list': predicted_object_list,
            })

    # Save all predictions.
    with open(output_filename, 'w') as fout:
        json.dump(all_prediction_list, fout, indent=4, sort_keys=True)
    print(f'LayoutGPT ({args.llm_type}) prediction results written to {output_filename}')
# NOTE(review): the statements that used to follow here were a stray duplicate
# of the beginning of form_prompt_for_chatgpt (defined above) interleaved with
# tutorial prose: creating message_list, building the system prompt with
# args.canvas_size, appending the system message, forming the final
# "Prompt: .../Layout:" user prompt, and counting its tokens against the input
# limit. Executed at module level they raise NameError (text_input and
# tokenizer are undefined here), so the fragment is preserved as comments only:
#   message_list = []
#   system_prompt = 'Instruction: Given a sentence prompt ... {}px ...'.format(
#       args.canvas_size, args.canvas_size, args.canvas_size)
#   message_list.append({'role': 'system', 'content': system_prompt})
#   final_prompt = f'Prompt: {text_input}\nLayout:'
#   total_length = len(tokenizer(system_prompt + final_prompt)['input_ids'])