[论文阅读笔记34] LISA (LISA: Reasoning Segmentation via Large Language Model) 代码精读

9 篇文章 0 订阅
9 篇文章 0 订阅

LISA是一个很好的Reason Segmentation的baseline, 其利用特殊的token [SEG]来微调多模态LLM和SAM的decoder来实现复杂逻辑下的prompt的推理分割. 其整体框图如下, 本篇文章精度此代码并作简单复现.

在这里插入图片描述


1. 推理流程

流程如下:

1.1 加载Tokenizer与模型

首先利用transformers库的AutoTokenizer从config文件中加载Tokenizer. LISA使用的Tokenizer是LLaMa相同的Tokenizer,填充方式是在序列右侧填充,最大长度是512:

tokenizer = AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,  # 512
        padding_side="right",
        use_fast=False,  # 如果选择true 则模型自动进行填充 选择False的流程是首先对输入进行encode 然后再填充
    )

然后设定填充的unknown token以及本工作提出的代表分割掩码的[SEG] token:

tokenizer.pad_token = tokenizer.unk_token  # unknown token为 <unk>
args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]  # 在codebook中 [SEG]的id是32000
# tokenizer返回一个dict, 有两个key, input_ids表示在词汇表中的索引, attention_mask表示是否为填充的token 用于在注意力计算的时候看哪些token应该被注意

然后根据tokenizer的设定,设置model的句子开始、结束以及填充的token id:

model.config.eos_token_id = tokenizer.eos_token_id  # </s>
model.config.bos_token_id = tokenizer.bos_token_id  # <s>
model.config.pad_token_id = tokenizer.pad_token_id  # <unk>

随后,加载CLIP预训练的ViTal-large模型,作为LLaVA中的vision encoder:

model.get_model().initialize_vision_modules(model.get_model().config)
vision_tower = model.get_model().get_vision_tower()
vision_tower.to(dtype=torch_dtype)

加载CLIP的图像预处理类,包含resize、crop等,以及SAM中用到的resize类,其按照图像的最长边进行等比例resize:

clip_image_processor = CLIPImageProcessor.from_pretrained(model.config.vision_tower)
transform = ResizeLongestSide(args.image_size)

1.2 推理主要过程

按照LLaVA规定的格式先实例化一个Conversation类,这个类是自定义的一个数据类,用以保存所有的对话历史:

conv = conversation_lib.conv_templates[args.conv_type].copy()
"""
规定一开始的system设定 角色(用户和bot) 对话历史 分隔的token等
Conversation(system="A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=('USER', 'ASSISTANT'), messages=(), offset=0, sep_style=<SeparatorStyle.TWO: 2>, sep=' ', sep2='</s>', version='v1', skip_next=False)
"""
conv.messages = []  # 初始化

读取文本prompt, 例如who is the oldest person? Please output segmentation mask.

读取文本后,要在文本prompt之前加入给图像预留的token <image>, 并且在<image>前后加入起止符:

prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)  # <im_start><image><im_end>\nwho is the oldest person? Please output segmentation mask.

随后组合成完整的对话prompt:

A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <im_start><image><im_end>\nwho is the oldest person? Please output segmentation mask. ASSISTANT:

读取图片,并用CLIP的image preprocess将图像缩放为224x224, 并遵循SAM的预处理将图像长边缩放至1024, 并填充至1024x1024. 两种预处理分别对应两个vision encoder.

随后将文本prompt进行tokenize, 得到input_ids:

input_ids = tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
input_ids = input_ids.unsqueeze(0).cuda()  # [bs, length]
"""
tensor([[    1,   319, 13563,  1546,   263, 12758,  5199,   322,   385, 23116,
         21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
           322,  1248,   568,  6089,   304,   278,  5199, 29915, 29879,  5155,
         29889,  3148,  1001, 29901, 32001,  -200, 32002,  1058,   338,   278,
         23947,  2022, 29973,  3529,  1962, 10768,   362, 11105, 29889,   319,
          1799,  9047, 13566, 29901]], device='cuda:0')
"""

将input_ids输入LISA模型, 实际上走的是LLaVA, 并产生文本输出, 类似于训练时预规定的"Sure, It is [SEG]":

with torch.no_grad():
    outputs = self.generate(  # transformers库的方法 采用greedy生成 即每次选logits最大的token输出
        images=images_clip,
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        num_beams=1,
        output_hidden_states=True,  # 输出hidden state 是因为我们要取[SEG]对应的embedding来decode分割的mask
        return_dict_in_generate=True,
    )
"""
输出的句子(outputs.sequence)为:
tensor([[    1,   319, 13563,  1546,   263, 12758,  5199,   322,   385, 23116,
         21082, 20255, 29889,   450, 20255,  4076,  8444, 29892, 13173, 29892,
           322,  1248,   568,  6089,   304,   278,  5199, 29915, 29879,  5155,
         29889,  3148,  1001, 29901, 32001,  -200, 32002,  1058,   338,   278,
         23947,  2022, 29973,  3529,  1962, 10768,   362, 11105, 29889,   319,
          1799,  9047, 13566, 29901, 18585, 29892, 32000,   869,     2]],
       device='cuda:0')
翻译过来就是:
<s>A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <im_start> <im_end> who is the oldest person? Please output segmentation mask. ASSISTANT: Sure, [SEG] .</s>
"""

output_hidden_states = outputs.hidden_states[-1]  # 取最后一层的hidden state [bs, 313, 5120]
output_ids = outputs.sequences

之后是重要的步骤: 把[SEG]在输出中对应的位置取出来(得到一个mask, 只有[SEG]在的位置是1), 并在前面填充255个False, 原因是:在LLaVA的推理过程中, vision encoder会将224x224的图像切分成14x14大小的patch, 共256个, 所以等价的输出也是256个, 因此在真实的LLaVA输出中, 长度比原本多了256 - 1(<image>符).

eg_token_mask = output_ids[:, 1:] == self.seg_token_idx
# hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
seg_token_mask = torch.cat(
    [
        torch.zeros((seg_token_mask.shape[0], 255)).bool().cuda(),
        seg_token_mask,
    ],
    dim=1,
)

取出[SEG]对应的hidden state, 并经过一个MLP层用以对齐:

hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states))

last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
pred_embeddings = last_hidden_state[seg_token_mask]

然后是SAM阶段, 对SAM预处理的image经过encoder, 然后将[SEG]对应的hidden state作为prompt, 也进行encode:

image_embeddings = self.get_visual_embs(images)  # [1, 256, 64, 64] 其中 64 = 1024 / 16 patch大小为16 x 16

multimask_output = False
pred_masks = []
for i in range(len(pred_embeddings)):
    (
        sparse_embeddings,  # [1, 1, 256]
        dense_embeddings,  # [1, 256, 64, 64]
    ) = self.model.visual_model.prompt_encoder(
        points=None,
        boxes=None,
        masks=None,
        text_embeds=pred_embeddings[i].unsqueeze(1),
    )

随后执行常规的SAM decoder过程, 得到分割掩码. 最后, 对output_id进行解码得到文本输出, 以及对mask进行可视化与保存即可.

low_res_masks, iou_predictions = self.model.visual_model.mask_decoder(
    image_embeddings=image_embeddings[i].unsqueeze(0),
    image_pe=self.model.visual_model.prompt_encoder.get_dense_pe(),  # 图像的位置编码
    sparse_prompt_embeddings=sparse_embeddings,
    dense_prompt_embeddings=dense_embeddings,
    multimask_output=multimask_output,  # False, LISA只默认分割一个物体
)
pred_mask = self.model.visual_model.postprocess_masks(  # 将低分辨率mask [224, 224]后处理回原本的分辨率
    low_res_masks,
    input_size=resize_list[i],
    original_size=original_size_list[i],
)

text_output = tokenizer.decode(output_ids, skip_special_tokens=False)

1.3 运行结果

可以看出来,它并没有很正确地分割出来最“老”的人(应该是画面最右侧的), 而是倾向于分割所有的人, 说明LISA可能存在对特定文本忽视的现象, 这在一些生成的工作中有人关注过.

2. 训练流程

2.1 数据准备与读取

LISA的训练数据包括四个任务:

  • 语义分割
  • 指令分割
  • VQA
  • 推理分割

在训练的时候, 将四种任务的若干数据集混合, 封装成统一的HybridDataset类, 在每次迭代的时候, 都随机从四个任务中挑选一个任务, 再从挑选的任务中随机选一个数据集, 再从这个数据集中随机选一个样本, 代码如下:

# HybridDataset:
def __getitem__(self, idx):
    ind = np.random.choice(list(range(len(self.datasets))), p=self.sample_rate)
    data = self.all_datasets[ind]
    inference = False
    return *data[0], inference
    
# 传到单个任务的数据集类中索引是0 但是也是随机选一个数据集之后随机选一个样本 例如对应Refferring Segment:
def __getitem__(self, idx):
    ds = random.randint(0, len(self.refer_seg_ds_list) - 1)
    ds = self.refer_seg_ds_list[ds]
    refer_seg_ds = self.refer_seg_data[ds]
    images = refer_seg_ds["images"]
    annotations = refer_seg_ds["annotations"]
    img2refs = refer_seg_ds["img2refs"]
    idx = random.randint(0, len(images) - 1)

接下来看一下每个任务的数据集是如何构建以及读取的.

2.2.1 HyBridDataset初始化

在训练的主函数(train_ds.py)中, 对混合数据集进行如下初始化:

train_dataset = HybridDataset(
    args.dataset_dir,  # 根目录 存放所有任务的所有数据集
    tokenizer,  # 采用LLaVA的tokenizer
    args.vision_tower,  # CLIP的ViT-large
    samples_per_epoch=args.batch_size  # 一个epoch的样本数 = bs * 梯度积累步数 * 一个epoch的步数 * 显卡数
    * args.grad_accumulation_steps  # 默认10 梯度累积的主要目的是在显存有限的情况下, 模拟大bs的训练效果
    * args.steps_per_epoch  # 默认500
    * world_size,
    precision=args.precision,  # 推理时候的精度 fp16/fp32
    image_size=args.image_size,  # 默认1024 SAM的输入分辨率
    num_classes_per_sample=args.num_classes_per_sample,  # 一个样本的标注中最多看几个类别 默认为3
    exclude_val=args.exclude_val,  # 是否排除验证集
    dataset=args.dataset,  # 默认四个任务都进行
    sample_rate=[float(x) for x in args.sample_rates.split(",")],  # 对每个任务的采样频率 
    # 默认语义分割: 指令分割: VQA: 因果分割 = 9: 3: 3: 1 可以看出是保住分割能力 并防止复杂prompt的过拟合
    sem_seg_data=args.sem_seg_data,  # 具体的语义分割的数据集名称
    refer_seg_data=args.refer_seg_data,  # 具体的指令分割的数据集名称
    vqa_data=args.vqa_data,  # 具体的VQA的数据集名称
    reason_seg_data=args.reason_seg_data,  # 具体的推理分割的数据集名称
    explanatory=args.explanatory,  # 这个参数是对ReasonSeg而言的, 问题是要求解释的问题("例如Please segment.. and explain why")的比例, 默认是0.1. 加入VQA数据集训练的目的也是保障模型回答问题的能力.
)
2.2.2 语义分割

在初始化函数中, 分别定义图像预处理方法, 问题模板, 回答模板以及每个数据集的样本的class名称, 图像以及label:

self.transform = ResizeLongestSide(image_size)
self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)

self.short_question_list = SHORT_QUESTION_LIST  # 例如 <image> + "\n" + "Can you segment the {class_name} in this image?"
self.answer_list = ANSWER_LIST  # 例如 "It is [SEG]." "Sure, [SEG]."等

self.data2list = {}  # key: 数据集 value: (images, labels), images 和 labels为长度为N的列表, 里面存储路径
self.data2classes = {}  # key: 数据集 value: class名称的np.ndarray

# 存储每一个数据集
self.sem_seg_datas = sem_seg_data.split("||")
for ds in self.sem_seg_datas:
    classes, images, labels = eval("init_{}".format(ds))(base_image_dir)
    self.data2list[ds] = (images, labels)
    self.data2classes[ds] = classes

__getitem__中, 首先随机选择数据集, 然后再从数据集中随机选择一个样本进行图像和label的读取和resize. 此外, 读取对应的类别(如果超过规定的数目, 就随机抽取args.num_classes_per_sample个), 这部分不再赘述.

之后, 构建问题和答案:

questions = []
answers = []
class_ids = []  # 为样本中的每一个类别创建一组问答
for sampled_cls in sampled_classes:
    text = sampled_cls

    assert len(text.split("||")) == 1
    question_template = random.choice(self.short_question_list)  # 按照模板构建问题
    questions.append(question_template.format(class_name=text.lower()))

    answers.append(random.choice(self.answer_list))  # 随机选择答案模板

    if ds in ["paco_lvis", "pascal_part"]:  # 这两个数据集是single class 特殊处理
        continue

    class_id = self.data2classes[ds].tolist().index(sampled_cls)
    class_ids.append(class_id)

# 转换为标准的prompt 即 system: A chat... Human: XXX Assistant: XXX
conversations = []
conv = conversation_lib.default_conversation.copy()

i = 0
while i < len(questions):
    conv.messages = []
    conv.append_message(conv.roles[0], questions[i])
    conv.append_message(conv.roles[1], answers[i])
    conversations.append(conv.get_prompt())
    i += 1

随后读取label中的mask, 这部分省略, 返回值是如下的格式, 其余数据集也遵循:

return (
    image_path,  # 图像路径
    image,  # 用于SAM的resize图像 应该是1024x1024
    image_clip,  # 用于CLIP的resize图像 应该是224x224
    conversations,  # 真值conversation
    masks,  # 真值masks shape: [n, h, w] n是对应的物体类别数 [h, w]是原大小
    label,  # [h, w], 原始分割标签
    resize,  # [1024, 1024]
    questions,  # 问题
    sampled_classes,  # 抽取的类别名称 list
)
2.2.3 指令分割

基本流程和语义分割是相似的, 只不过class name需要从annotation的referring中读出来:

img2refs = refer_seg_ds["img2refs"]
refs = img2refs[image_id]  # 得到图像对应的referrings

# 读取对应所有referring的文本 当然后面要根据args.num_classes_per_sample作筛选
sents = []
ann_ids = []
for ref in refs:
    for sent in ref["sentences"]:
        text = sent["sent"]
        sents.append(text)
        ann_ids.append(ref["ann_id"])
        
# 因此
sampled_classes = sampled_sents

# 后面读取mask也类似, 要根据抽取出的referring找到对应的mask
2.2.4 VQA

VQA比较特殊. 直接从数据集中读取数据即可构建conversation, 对于mask和label, 则将mask置为全0, label都置为ignore_label(255):

conversations = []
if roles[source[0]["from"]] != conv.roles[0]:
    # Skip the first one if it is not from human
    source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
    role = roles[sentence["from"]]  # 直接从数据集读取
    assert role == conv.roles[j % 2], f"{i}"
    conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())

questions = conversations
sampled_classes = conversations

image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous())

masks = torch.rand(0, *ori_size)  # 全0
label = torch.ones(ori_size) * self.ignore_label  # 全255
2.2.5 ReasonSeg

ReasonSeg最大的不同就是要处理长问话以及解释性的问话,

首先读取当前随机抽取样本的mask, 问话以及是否为一个句子:

mask, sents, is_sentence = get_mask_from_json(json_path, image)

随后看是否为需要解释的样本, 如果是的话就构建对应的问话. 其中的choice是在这种情况下, 进一步控制是解释性问题的比例, 其实是让解释性问题的占比进一步降低了.

if is_sentence:
    question_template = random.choice(self.long_question_list)
    questions.append(question_template.format(sent=text))
else:
    question_template = random.choice(self.short_question_list)
    questions.append(question_template.format(class_name=text.lower()))

# add explanation if applicable
img_name = image_path.split("/")[-1]
if self.explanatory != -1 and img_name in self.img_to_explanation:
    if choice == 0:  # [SEG] token  # 最简单的回答
        answers.append(random.choice(self.answer_list))
    elif choice == 1:  # [SEG] token + text answer  # 否则加入解释性的提问
        image_name = image_path.split("/")[-1]
        answer = self.img_to_explanation[image_name]["outputs"]
        answer = random.choice(self.answer_list) + " {}".format(answer)
        questions[-1] = (
            DEFAULT_IMAGE_TOKEN
            + "\n"
            + text
            + " {}".format(random.choice(self.explanatory_question_list))
        )
        answers.append(answer)
    elif choice == 2:  # vanilla text answer  # 不加入
        image_name = image_path.split("/")[-1]
        answer = self.img_to_explanation[image_name]["outputs"]
        questions[-1] = DEFAULT_IMAGE_TOKEN + "\n" + text
        answers.append(answer)
    else:
        raise ValueError("Not implemented yet.")
else:
    answers.append(random.choice(self.answer_list))

2.2 单步训练流程与损失计算

2.2.1 模型载入

模型载入和推理过程基本是相似的. 但是训练过程中需要用LoRA来微调LLaVA (要训练生成固定的答案模板), 具体做法如下:

lora_r = args.lora_r  # 降到秩为多少 默认为8
if lora_r > 0:
	
    # 查找模型中符合条件的线性层 并保存下来
    def find_linear_layers(model, lora_target_modules):
        cls = torch.nn.Linear
        lora_module_names = set()
        for name, module in model.named_modules():
            if (
                isinstance(module, cls)
                and all(
                    [
                        x not in name
                        for x in [
                            "visual_model",
                            "vision_tower",
                            "mm_projector",
                            "text_hidden_fcs",
                        ]
                    ]
                )
                and any([x in name for x in lora_target_modules])
            ):
                lora_module_names.add(name)
        return sorted(list(lora_module_names))

    lora_alpha = args.lora_alpha  # lora超参
    lora_dropout = args.lora_dropout  # lora超参
    lora_target_modules = find_linear_layers(
        model, args.lora_target_modules.split(",")
    )
    # 配置config 利用peft库实现lora
    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    # 根据lora要改变的线性层更新模型
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
2.2.2 前向传播

在每次迭代中, 首先用SAM的image encoder得到图像特征, 以及, 得到[SEG]在tokenize的输入中的位置:

image_embeddings = self.get_visual_embs(images)  # [bs, c, h, w]
batch_size = image_embeddings.shape[0]
assert batch_size == len(offset) - 1

seg_token_mask = input_ids[:, 1:] == self.seg_token_idx  # [bs, N]
seg_token_mask = torch.cat(
    [
        seg_token_mask,
        torch.zeros((seg_token_mask.shape[0], 1)).bool().cuda(),  # [bs, 1]
    ],
    dim=1,
)  # [bs, N + 1]
# hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
# 补齐255个0的理由和推理节介绍的相同
seg_token_mask = torch.cat(
    [torch.zeros((seg_token_mask.shape[0], 255)).bool().cuda(), seg_token_mask],
    dim=1,
)

然后输入LLaVA:

images_clip_list = []
for i in range(len(offset) - 1):
    start_i, end_i = offset[i], offset[i + 1]  # 该样本具有多少annotation, 即问答对的起始和最终的idx
    images_clip_i = (
        images_clip[i]
        .unsqueeze(0)
        .expand(end_i - start_i, -1, -1, -1)  # 就重复这么多遍
        .contiguous()
    )
    images_clip_list.append(images_clip_i)
images_clip = torch.cat(images_clip_list, dim=0)

output = super().forward(  # 得到LLaVA的结果
    images=images_clip,
    attention_mask=attention_masks,
    input_ids=input_ids,
    labels=labels,
    output_hidden_states=True,
)
output_hidden_states = output.hidden_states

得到最后一层的各个batch中[SEG]的embedding:

hidden_states = []

assert len(self.model.text_hidden_fcs) == 1
hidden_states.append(self.model.text_hidden_fcs[0](output_hidden_states[-1]))  # 输入FC层对齐SAM和LLavA

last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
pred_embeddings = last_hidden_state[seg_token_mask]
seg_token_counts = seg_token_mask.int().sum(-1)  # [bs, ]

# 得到每个batch中seg token的起始位置 并获得相应的embeddings
seg_token_offset = seg_token_counts.cumsum(-1)
seg_token_offset = torch.cat(
    [torch.zeros(1).long().cuda(), seg_token_offset], dim=0
)

seg_token_offset = seg_token_offset[offset]

pred_embeddings_ = []
for i in range(len(seg_token_offset) - 1):
    start_i, end_i = seg_token_offset[i], seg_token_offset[i + 1]
    pred_embeddings_.append(pred_embeddings[start_i:end_i])
pred_embeddings = pred_embeddings_

遍历每个embeddings, 用SAM的decoder得到mask, 这部分和推理过程相似, 不再赘述. 最后计算loss. 一个是训练LLaVA用的交叉熵损失, 用于文本输出和模板一致; 另外就是分割常用的bce和dice loss. 注意对于VQA任务的样本, mask loss理应是0.

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值