释放图像处理潜能:深入解析 PaLiGemma 模型的调整与部署

释放图像处理潜能:深入解析 PaLiGemma 模型的调整与部署

原文:Finetune and Deploy Custom PaLiGemma Model for your Image Tasks

简介

PaLiGemma 是一个开源的最先进模型,与其他产品一起在 Google I/O 2024 上发布,结合了 Google 开发的另外两个模型。基于 SigLIP 视觉模型和 Gemma 语言模型等开放组件,PaLiGemma 是一个灵活且轻量级的视觉-语言模型(VLM),灵感来自 PaLI-3。它支持多种语言,在接受图像和文本输入后生成文本输出。它旨在作为各种视觉-语言活动的模型,包括文本阅读、对象识别和分割、视觉问答以及为图像和短视频加上标题。

与其他在对象检测和分割方面遇到困难的 VLM 不同,特别是 OpenAI 的 GPT-4o、Google Gemini 和 Anthropic 的 Claude 3,PaLiGemma 提供了各种功能,并可以进行微调以在特定任务上提高性能。

在今天的博客中,我们将学习如何对 PaLiGemma 模型进行微调并将其部署到其中一个服务提供商上。在整个教程中,我们将使用 Roboflow 以期望的格式轻松访问数据集,使用 Kaggle 加载模型权重,最后使用 Azure 虚拟机。对于此任务,具有 NVIDIA T4 GPU 的 Colab 实例将足够。

学习目标

在本博客中,您将学到:

  • 关于 PaLiGemma 模型及其组件的知识。
  • 如何设置环境以微调 PaLiGemma。
  • 以 JSONL 格式准备数据的技术。
  • 下载和配置 PaLiGemma 模型权重的过程。
  • 微调 PaLiGemma 并保存微调后的模型的步骤。
  • 使用 Azure 虚拟机部署微调后的模型的策略。

目录

开始之前

在阅读本博客之前,您应该熟悉 Python 编程和大型语言模型(LLMs)的训练过程。虽然不是必需的,但在检查示例代码片段时,对 JAX(或类似技术如 Keras)有基本了解将是有益的。

此外,为了微调 PaLiGemma,我们将按照以下步骤进行:

  1. 安装所需的依赖项
  2. 下载任何图像数据集以符合 PaliGemma 的 JSONL 格式
  3. 从 Kaggle 下载预训练的 PaliGemma 权重和分词器
  4. 使用 JAX 对 PaLiGemma 进行微调
  5. 保存我们的模型以备后用
  6. 部署微调后的模型

步骤 1:安装和设置模型

A. PaliGemma 和 Kaggle 设置

对于首次使用者,我们必须通过 Kaggle 请求 PaLiGemma 访问权限并配置我们的 API 密钥,具体步骤如下。

  1. 登录或注册 Kaggle 帐号: 登录您的 Kaggle 帐号,如果没有帐号,则创建一个新帐号。
  2. 请求访问 PaliGemma: 转到 Kaggle 上的 PaLiGemma 模型页面,点击“请求访问”,完成同意表单并接受条款和条件。
  3. 生成 Kaggle API 密钥: 在 Kaggle 上的设置页面中点击“创建新令牌”以下载包含您的 API 凭据的 kaggle.json 文件。
  4. 将 Kaggle API 密钥添加到 Colab: 在 Colab 中,选择左侧窗格中的“Secrets”(🔑),添加您的 Kaggle 用户名和 API 密钥。将您的用户名存储在 KAGGLE_USERNAME 下,将 API 密钥存储在 KAGGLE_KEY 下。
  5. 安全存储凭据: 确保您的 Kaggle API 密钥安全存储,并仅在需要访问 Kaggle 数据集或模型时使用。

完成所有步骤后,设置环境变量如下所示。

import os
from google.colab import userdata
# 注意:`userdata.get` 是 Colab 的 API。如果您没有使用 Colab,请根据需要设置适当的环境变量或使您的凭据在 ~/.kaggle/kaggle.json 中可用
os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

B. 获取 big_vision 仓库和依赖项

为了微调 PaLiGemma 模型,我们将使用由 Google Research 维护的 big_vision 项目。以下代码可以在您的笔记本中安装仓库和相应的依赖项。

import os
import sys
# 如果在环境变量中找到 "COLAB_TPU_ADDR",则表示正在使用不受支持的远程 TPUs,将会报错
if "COLAB_TPU_ADDR" in os.environ:
  raise "It seems you are using Colab with remote TPUs which is not supported."
# 如果 python 不知道 big_vision 仓库,就会获取并安装此 notebook 需要的依赖
if not os.path.exists("big_vision_repo"):
  !git clone --quiet --branch=main --depth=1 \
     https://github.com/google-research/big_vision big_vision_repo
# 将 big_vision 代码添加到 python 的导入路径中
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")
# 安装缺失的依赖。假设 jax~=0.4.25 并且有 GPU 可用
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

C. 导入 JAX 和相关依赖项

下面的代码将导入必要的框架,如 JAX,以完成模型设置。

import base64
import functools
import html
import io
import os
import warnings
import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece
from IPython.core.display import display, HTML
from PIL import Image
# 从 big_vision 导入模型定义
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
# 导入 big vision 实用工具
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding
# 不让 TF 使用 GPU 或 TPU
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
backend = jax.lib.xla_bridge.get_backend()
print(f"JAX 版本:  {jax.__version__}")
print(f"JAX 平台: {backend.platform}")
print(f"JAX 设备数:  {jax.device_count()}")

步骤 2:选择适合任务的数据并以 JSONL 格式准备

对于使用 PaLiGemma 进行微调任务,我们需要将数据准备成 PaLiGemma JSONL 格式。您可能对这种格式不太熟悉,因为它不是常见的数据格式(如 YOLO)用于图像任务,但是 JSONL(JSON Lines)通常用于训练大型模型,因为它允许高效的逐行处理。下面是数据存储的 JSONL 格式示例。

{"name": "John Doe", "age": 30, "city": "New York"}
{"name": "Jane Smith", "age": 25, "city": "Los Angeles"}
{"name": "Sam Brown", "age": 22, "city": "Chicago"}

将数据准备成 JSONL 格式很容易,下面我提供了示例代码。

import json
import os
# 包含图像的目录
image_dir = '/path/to/images'
# 包含图像标签的字典
labels = {
    "image1.jpg": "label1",
    "image2.jpg": "label2",
    "image3.jpg": "label3"
}
# 创建包含图像路径和标签的字典列表
data = []
for image_name, label in labels.items():
    image_path = os.path.join(image_dir, image_name)
    data.append({"image_path": image_path, "label": label})
# 将数据写入 JSONL 文件
with open('images_labels.jsonl', 'w') as file:
    for entry in data:
        file.write(json.dumps(entry) + '\n')

然而,在这里我们将使用 Roboflow 来轻松完成任务。Roboflow 已经完全支持 PaLiGemma JSONL 格式,可以用于从 Roboflow Universe 访问任何数据集。您可以使用 Roboflow API 密钥根据任务要求使用任何数据集。下面是显示如何实现相同目的的代码片段。

# 安装下载和解析数据集所需的依赖项
!pip install roboflow supervision
from google.colab import userdata
from roboflow import Roboflow
ROBOFLOW_API_KEY = userdata.get('ROBOFLOW_API_KEY')
rf = Roboflow(api_key=ROBOFLOW_API_KEY)
project = rf.workspace("workspace-user-id").project("sample-project-name")
version = project.version(#enterversionnumber)
dataset = version.download("PaliGemma")

现在,我们已经成功完成了模型设置,并以所需的格式和平台导入了数据,我们可以获取 PaLiGemma 权重以进一步微调模型。

步骤 3:下载并配置 PaLiGemma 模型权重

这一步涉及从 Kaggle 下载 PaLiGemma 权重。为了在有限资源中进行简单计算,我们将使用 paligemma-3b-pt-224 版本。JAX/FLAX PaliGemma 3B 有三个不同版本,分别是输入图像分辨率(224、448 和 896)和输入文本序列长度(128、512 和 512 个标记)不同。

可以通过运行以下代码从 Kaggle 下载模型的 float16 版本的检查点。这个过程可能会花费一些时间。

import os
import kagglehub
MODEL_PATH = "./pt_224_128.params.f16.npz"
if not os.path.exists(MODEL_PATH):
  MODEL_PATH = kagglehub.model_download
  ('google/paligemma/jax/paligemma-3b-pt-224', 'paligemma-3b-pt-224.f16.npz')
  print(f"Model path: {MODEL_PATH}")
TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATPaLiGemma modelH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")
DATA_DIR="./longcap100"
if not os.path.exists(DATA_DIR):
  print("Downloading the dataset...")
  !gsutil -m -q cp -n -r gs://longcap100/ .
  print(f"Data path: {DATA_DIR}")

下一步需要配置并移动模型以适配 Colab T4 GPU。要设置模型,首先将 model_config 初始化为 FrozenConfigDict,这有助于冻结某些参数并减少内存使用。然后,创建 PaliGemma Model 类的实例,使用 model_config 进行设置。将模型参数加载到内存中,并定义一个解码函数,以从模型中对输出进行采样。完成后,模型可以移动到 T4 GPU。下面的代码将指导这两个步骤。

# 定义模型
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True,
     "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)
# 加载参数 - 这在 T4 colab 中可能需要最多 1 分钟的时间。
params = paligemma.load(None, MODEL_PATH, model_config)
# 定义 `decode` 函数以从模型中对输出进行采样。
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), 
eos_token=tokenizer.eos_id())
# 将模型移动到 T4 GPU
# 创建可训练参数的 pytree mask。
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  if name.startswith("llm/layers/attn/"):  return True
  if name.startswith("llm/"):              return False
  if name.startswith("img/"):              return False
  raise ValueError(f"Unexpected param name {name}")
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)
# 如果有多个设备可用(例如多个 GPU),则可以将参数分片以减少每个设备的 HBM 使用。
mesh = jax.sharding.Mesh(jax.devices(), ("data"))
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))
params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)
# 是的:一些捐赠的缓冲区无法使用。
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,
                      params, trainable)
# 同时加载所有参数 - 尽管更快且更简洁 - 需要比 T4 colab 默认拥有的 RAM 更多。
# 相反,我们逐个参数进行加载。
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, 
trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)
# 打印参数以显示模型的组成。
def parameter_overview(params):
  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")
print(" == Model params == ")
parameter_overview(params)

这一步已经完成了我们微调过程的所有必要步骤,因此我们可以继续进行下一步。

第四步:微调 PaLiGemma

在进行微调之前,必须执行一些额外的检查和预处理步骤。这些是标准程序,它们的代码会很长,因此不在当前范围内。这些细节可以在后续章节提到的其他开源资源中找到。不过,下面提到了这些步骤的概述。

  1. 创建模型输入
    • 通过将图像转换为灰度图像、去除 alpha 通道并将其调整为 224×224 像素来规范化图像数据。
    • 通过添加标记来标记标记是否为前缀或后缀的标记,以便在训练和评估过程中使用。
    • 删除序列结束(EOS)标记后的标记,并返回剩余的解码标记。
  2. 创建训练和验证迭代器
    • 定义训练迭代器以处理数据块、对示例进行洗牌并重复多个周期。使用适当的标记对图像进行预处理和对文本进行标记化。
    • 定义验证迭代器以有序地处理验证数据,对图像进行预处理并对文本进行标记化。
  3. 查看训练示例
  • 展示随机选择的训练图像及其描述,以了解模型训练的数据。
  1. 定义训练和评估循环
    • 实现随机梯度下降(SGD)训练循环以优化模型参数。计算每个示例的损失,从损失计算中排除前缀和填充的标记。
    • 实现评估循环,对验证数据集进行预测,处理小数据集的填充,并确保输出中只计算实际示例。

完成所有这些步骤后,现在我们可以对模型进行微调。下面的代码将实现相同的功能。它在模型上运行了64个步骤的训练循环,并在每个步骤显示学习率(lr)和损失率。每16个步骤,它输出相同一组图像的模型预测,让您观察模型预测描述能力的改善。在训练初期,预测可能包含重复或不完整的句子,但随着训练的进行,描述的准确性会提高。到第64步,模型的预测应该与训练数据中的描述非常接近。

BATCH_SIZE = 8
TRAIN_EXAMPLES = 512
LEARNING_RATE = 0.03
TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
EVAL_STEPS = TRAIN_STEPS // 4
train_data_it = train_data_iterator()
sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
    decay_type="cosine", warmup_percent=0.10)
for step in range(1, TRAIN_STEPS+1):
  # Make list of N training examples.
  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]
  # Convert list of examples into a dict of np.arrays and load onto devices.
  batch = jax.tree.map(lambda *x: np.stack(x), *examples)
  batch = big_vision.utils.reshard(batch, data_sharding)
  # Training step and report training loss
  learning_rate = sched_fn(step)
  params, loss = update_fn(params, batch, learning_rate)
  loss = jax.device_get(loss)
  print(f"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}")
  if (step % EVAL_STEPS) == 0:
    print(f"Model predictions at step {step}")
    html_out = ""
    for image, caption in make_predictions(
        validation_data_iterator(), num_examples=4, batch_size=4):
      html_out += render_example(image, caption)
    display(HTML(html_out))

现在,您可以使用预定义的名为 make_predictions 的函数测试微调后的模型,该函数会迭代处理图像并对每个图像执行推断。此函数可用于测试我们微调后的目标检测模型。

print("模型预测")
html_out = ""
for image, caption in make_predictions(validation_data_iterator(), batch_size=4):
  html_out += render_example(image, caption)
display(HTML(html_out))

以下是每次迭代的模型输出样本。对于当前目的,微调是为了演示目的而进行的30步。数据集、步数和其他超参数也会根据您的使用和需求而改变。

步骤 5:保存微调后的模型

一旦微调完成并检查了模型预测,为了进一步使用相同的模型或将其部署到后续阶段,可以使用以下代码进行保存:

flat, _ = big_vision.utils.tree_flatten_with_names(params)
with open("/content/fine-tuned-PaliGemma-3b-pt-224.f16.npz", "wb") as f:
  np.savez(f, **{k: v for k, v in flat})

步骤 6:部署微调后的模型

对于部署,我们将依赖于 Roboflow 推理服务器,并将其部署到 AWS EC2 实例上。Roboflow 推理服务器允许您将计算机视觉模型部署到包括 AWS EC2 在内的各种设备上。推理服务器依赖于 Docker 运行。如果您尚未在要运行推理的设备上安装 Docker,请按照官方 Docker 安装说明进行安装。安装 Docker 后,运行以下命令在 AWS EC2 上下载 Roboflow 推理服务器。

pip install inference supervision

现在,Roboflow 推理服务器将运行,并且您可以在 EC2 服务器上使用微调后的模型。

结论

在本博客中,我们已经详细介绍了谷歌的尖端视觉语言模型 PaLiGemma 的微调和部署过程。从安装必要的依赖项并设置环境开始,我们利用了各种工具和平台,包括 Kaggle 用于访问模型权重,Roboflow 用于数据集准备,以及 Azure 虚拟机 用于部署。通过遵循这些步骤,您可以利用 PaLiGemma 的强大功能进行各种视觉语言任务,如目标检测、图像字幕和视觉问题回答。我希望这篇指南为您提供了一个清晰而实用的途径,以增强您的项目的先进人工智能能力。

主要收获
  • 整合先进模型: PaLiGemma 结合了 SigLIP 和 Gemma 的能力,提供了一个多语言、多任务的多功能轻量级视觉语言模型。
  • 增强的视觉语言能力: 与许多其他 VLM 不同,PaLiGemma 有效处理目标检测和分割,使其成为各种视觉语言活动的强大选择,包括文本阅读、视觉问题回答以及图像/视频字幕。
  • 逐步微调过程: 本教程提供了微调 PaLiGemma 的详细逐步指南,涵盖了设置依赖项、准备数据以及使用 JAX 配置模型权重等关键步骤。
  • 资源的高效利用: 本教程通过利用 Roboflow 进行数据集准备、Kaggle 获取模型权重以及 Azure 虚拟机进行部署等工具,展示了高效的资源管理和实用的部署策略。
  • 实际应用和部署: 本指南以在 EC2 服务器上部署微调模型为结尾,展示了如何将理论知识应用于实际情况,并使用户能够在现实场景中利用 PaLiGemma 的能力。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

数智笔记

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值