TowardsDataScience 2023 博客中文翻译(八十六)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

创建你自己的生成 AI 文本到图像 API

原文:towardsdatascience.com/create-your-own-generative-ai-text-to-image-api-548c07a4d839

将你的随想转化为杰作,按需制作

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 奥默尔·马赫穆德

·发表于 Towards Data Science ·17 分钟阅读·2023 年 4 月 12 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者使用 Midjourney 根据一般商业条款生成。

TL;DR

  • 生成 AI 的最新进展导致了一系列服务的推出,如 DALL-E 2、Midjourney 和 Stability AI,它们有潜力彻底改变我们对内容创作的方式。

  • 在这篇文章中,我将展示如何通过 API 构建并提供你自己高性能的文本到图像服务。基于 稳定扩散 通过 HuggingFace,使用 Vertex AI Workbench 和 Endpoints。

🚣🏼 我们是如何到达这里的

正如乔治·劳顿在他的 文章 中提到的:“生成 AI 是一种人工智能技术,能够生成各种类型的内容,包括文本、图像、音频和合成数据。最近围绕生成 AI 的热议是由于新用户界面的简便性,这些界面可以在几秒钟内创建高质量的文本、图形和视频。”[2]

机器学习并不是什么新鲜事,事实上,自 1960 年代以来,它以某种形式存在。“但直到 2014 年,随着 生成对抗网络(GANs)的引入,一种机器学习算法,生成 AI 才得以创造令人信服的真实人物图像、视频和音频。”[2]

结合能够接收自然语言提示并生成照片级真实图像的大型语言模型(LLMs)的能力,我们在短时间内取得了巨大的进步。第一个做到这一点的是OpenAI 的 DALL·E,于 2022 年 4 月推出,随后是 2022 年 8 月的 Disco Diffusion,最终被稳定扩散所取代。与这些产品并行的是一家名为Midjourney的公司,它开发了一个非常受欢迎的模型,通过与 Discord 机器人互动使用。

从那时起,艺术状态的进展令人惊叹。下面的截图展示了仅在几个月的时间内所取得的成就——两个不同的模型提供了相同的提示,但早期模型和后期模型之间的对比非常明显!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1:生成性人工智能艺术的进展,由 Disco Diffusion 和 Midjourney 生成的图像。

快进到 2023 年,Midjourney 于 3 月 15 日刚刚发布了第 5 版。这个模型具有非常高的连贯性,擅长解释自然语言提示,分辨率更高,并支持像重复图案这样的高级功能[4]。Stability AI 也发布了稳定扩散 2.1。这些模型的发展速度非常迅猛!

正如德勤最近发布的白皮书中描述的:“我们感觉可能只是刚刚开始看到生成性人工智能模型所能带来的影响。虽然早期的吸引力主要来自消费者发布,这可能会成为定义时代的里程碑,但生成性人工智能也有潜力为几乎所有生活领域增加情境意识和类人决策能力。

因此,生成性人工智能吸引了来自传统(例如风险投资(VC),并购(M&A))和新兴(例如生态系统合作伙伴)来源的兴趣。仅在 2022 年,风险投资公司就投资了超过 20 亿美元,科技领袖也进行了重要投资,如微软对 OpenAI 的 100 亿美元投资和谷歌对 Anthropic 的 3 亿美元投资。”[3]

🧐 什么是稳定扩散?

稳定扩散是一种用于人工智能文本生成图像的方法,它利用扩散模型从文本描述中创建图像。扩散模型从一张随机图像开始,然后逐步向其中添加噪声。噪声以受控的方式添加,使得图像仍然可以辨认。

在每个扩散步骤中,图像变得更加精细,细节变得更加清晰。这个过程会持续几个扩散步骤,直到生成的图像被认为“稳定”,这意味着我们已经达到了一个进一步迭代不会改善图像的点。过程如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 2:Stable Diffusion 过程,图像来源: Benlisquare,根据许可共享: CC BY-SA 4.0

扩散模型在一个图像和文本描述的数据集上进行训练,它学会将文本描述与匹配的图像关联起来。当你给扩散模型一个新的文本描述时,它会利用其知识生成一个匹配描述的图像。

Stable Diffusion 的主要优势在于它非常快速。它可以在几秒钟内从文本描述生成图像。这比其他方法(如 GANs)快得多。我们的现代竞争者(例如 Midjourney、DALL E 等)都使用某种变体的 Stable Diffusion 技术。

但 Stable Diffusion 并不完美。虽然在过去一年中它的应用有了飞跃式的进步,但有时生成的图像可能会失真或缺乏细节。这可能是因为像这样的模型会根据训练数据来“猜测”图像应该是什么样子。随着训练数据集和模型算法的改进,这个问题将会减少。

👷🏾‍♀️ 开始吧!

在这篇文章中,我将向你展示如何使用一些代码来:

  1. 互动地实验 Stable Diffusion 模型,生成一些酷炫的艺术作品!

  2. 通过端点服务模型 使用 Vertex AI。

  3. 使用 FLASK 创建一个简单的 RESTful API

当然,你可以直接使用我们在前面章节中提到的现有的面向消费者的模型,但那样会没有乐趣啊?😜

➡️ 随意跳到你感兴趣的部分!

👾 我会把任何代码片段(如果没有链接的话)放在一个 github 仓库中,链接见最后的有用资源部分

👩🏻‍💻 1. 实验 Stable Diffusion

当你开始处理一个新的机器学习问题时,笔记本提供了一种非常灵活的方式来测试模型和快速迭代。也许你喜欢在本地环境中运行 Jupyter,使用 Kaggle Kernel,或者我个人最喜欢的 Colab。有了这些工具,创建和实验机器学习变得越来越可及。

我们将创建一个快速的笔记本,向你介绍 Stable Diffusion 模型

💨 如果你没有时间(或不想)一步步构建,可以查看资源部分下载笔记本文件!

前提条件:

  • 用于登录 Colab 和保存笔记本的 Google 账户

  • HuggingFace 账户和 API 令牌(免费)

  1. 访问 colab.research.google.com/。这将为你创建一个新的 Python 笔记本,并将其保存到你的 Google Drive 中。

  2. Colab 免费提供,但如果你需要更多的内存和计算资源来运行你的笔记本,你需要付费。我们将只使用免费的资源,首先你需要配置笔记本运行时以包含 GPU 加速器。使用 GPU 会加快模型生成图像的时间,也就是执行推理。免费层提供了 12GB 内存的 NVIDIA Tesla T4 GPU(这对我们要做的事情正好足够)。从文件菜单中选择 运行时 -> 更改运行时类型,然后选择硬件加速器 = GPU。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 3:Colab 笔记本,更改笔记本运行时。

然后你需要连接到你的运行时,以便我们可以执行笔记本中编写的任何代码。点击 连接/重新连接 -> 连接到托管运行时

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 4:Colab 笔记本,“+ 代码”按钮添加单元格,以及连接到托管运行时。

  1. 一旦你连接到你的运行时,就可以输入一些代码了!你可以在单个笔记本“单元格”中输入代码,或者如果你想单独执行步骤,可以使用“+ 代码”按钮创建一个新的“单元格”。我发现将代码分解成单元格很有用,这样更容易识别和调试问题。每个单元格在你悬停时都有一个“播放”按钮,点击它以运行代码。输入或复制粘贴代码到新的单元格中后,你可以运行以下步骤:

a. 由于我们需要从 HuggingFace 拉取模型,因此需要安装一些 Python 库并使用 API 令牌进行认证:

!pip install --upgrade huggingface_hub

b. 接下来,我们需要 使用你 HuggingFace 账户的令牌进行认证,当你执行以下代码时,你会在笔记本中被提示粘贴你的 HuggingFace API 令牌:

from huggingface_hub import notebook_login
notebook_login()

c. 成功登录你的 HuggingFace 账户后,我们将下载 diffusers 和 transformers Python 库:

!pip install -qq -U diffusers transformers

d. 我们需要创建一个 StableDiffusion 模型管道,以便我们可以将模型传递一些文本,并根据该提示生成图像。你可能会注意到我们传递的参数之一是 HuggingFace 上托管的 Stable Diffusion 模型的路径。此帖子中的示例使用的是 v1.5,你可以尝试在写作时交换为最新的(v2.1):

from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1–5')
.to('cuda')

e. 现在进入有趣的部分,我们将生成一些艺术作品!加载 torch Python 库,然后添加一个单元格,将 prompt 变量的字符串值替换为你想要的任何内容!我提供了一个简单的示例以帮助你入门:

import torch

# Initialize a prompt
prompt = “polar bear on an iceberg”

# Pass the prompt in the pipeline
pipe(prompt).images[0]

你应该得到如下所示的结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 5:Colab 笔记本,Stable Diffusion 模型生成的图像输出。

💡提示: 如果你遇到‘内存不足’的错误,你可以通过在文本提示上方添加以下代码行定期清除 GPU 的缓存:

torch.cuda.empty_cache()

只需在一个笔记本中写几行代码,你就可以从基于文本的提示生成艺术作品!

🍦 2. 通过 Vertex AI 端点服务模型

制作生产应用程序或训练大型模型需要额外的工具来帮助你超越笔记本中的代码,并且使用云服务提供商可以提供帮助。

我们的目标是将 Stable Diffusion 模型打包,并托管在一个可以处理来自应用代码的预测请求的端点上。

你可以使用其他公共云服务提供商来实现大致相同的结果,但我将使用 Google Cloud Platform (GCP),特别是它的 Vertex AI 工具集,因为我对此最为熟悉。

前提条件:

  • 启用计费的 GCP 账户/免费起始积分

  • 具备一些基本的 GCP 管理知识,例如如何创建项目、虚拟机、存储桶和其他资源,

  • 通过 IAM 给服务账户赋予权限,

  • 下载服务账户密钥,

  • 使用 Google Cloud SDK + 一些 Python 代码。

1. 我们将使用一个预构建的笔记本示例:github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/vertex_endpoints/torchserve/dreambooth_stablediffusion.ipynb,该示例来自 GCP Vertex AI 在 Github 上的官方仓库。我不会在这里详细复述整个内容,只会解释最相关的部分以及我在实际操作中遇到的问题……代码示例通常没有看起来那么简单!

⚠️ 在这个示例中创建用户管理的笔记本并将模型部署到端点将产生费用。当你第一次创建 GCP 账户时,可以获得 $300 的免费积分,但推荐的硬件配置(以及模型部署的端点的任何后续调用)可能会很快消耗这些积分——你已经被警告了。我将在“结语”部分分享我实际产生的费用。

2. 如果这是你第一次使用 Vertex AI,请登录到你的 GCP 控制台,并确保从 Vertex AI 仪表板启用所有必要的 API:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 6: GCP 中的 Vertex AI 仪表板,启用所有推荐的 API 按钮。

3. 打开 笔记本,在标题下方,你应该能看到一些按钮,包括“在 Colab 中运行”和“在 Vertex AI Workbench 中打开”。你需要点击后者。

4. 然后你需要配置一个 VM 实例来托管用户管理的笔记本。示例建议使用带有 85GB RAM 的 NVIDIA A100 实例,也就是“a2-highgpu-1g”。推理(从笔记本中的示例创建图像)很快,大约 11 秒钟就能从 Stable Diffusion 模型中返回图像。

5. 一旦你的实例创建完成,你可以通过 Workbench -> 用户管理的笔记本 访问它。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 7:Vertex AI 工作台,用户管理的笔记本列表。

在你深入并点击 OPEN JUPYTERLAB 之前,确保给你的笔记本所在 VM 的服务账户授予权限,以便进行诸如创建存储桶和端点等操作。由于它是“用户管理”的,这不会自动为你完成。

6. 点击笔记本名称,在笔记本详细信息中你将看到所有者或服务账户别名。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 8:Vertex AI,笔记本详细信息。

7. 转到 IAM & Admin -> IAM,然后点击 授予访问权限,粘贴或输入你的笔记本实例的服务账户别名,为了简化操作,我授予了“编辑者”访问权限。如果你很挑剔,你可以只授予 dreambooth_diffusion 示例步骤所需的特定权限。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 9:GCP IAM & Admin, IAM, 编辑访问/分配角色。

8. 现在你应该准备好在 Jupyterlab 中运行 dreambooth_diffusion.ipynb 笔记本代码了,从工作台中打开它(如步骤 5 所示)。由于某种原因,示例的 git 仓库中的代码没有被复制到我的笔记本实例中,所以我只是打开了终端并快速克隆了 GitHub 仓库:

$ git clone https://github.com/GoogleCloudPlatform/vertex-ai-samples

9. 一旦你在 Jupyterlab 中打开了 dreambooth_diffusion.ipynb,你应该能够运行笔记本单元而不会遇到重大问题。笔记本的第一部分是通过步骤下载 Stable Diffusion 模型并从提示中创建图像。下一步是创建 Vertex AI 端点并将模型部署到那里进行服务。

10. 按照笔记本中的步骤进行:

a. 创建自定义 TorchServe 处理程序。

b. 将模型工件上传到 Google Cloud Storage。

c. 使用模型工件和预构建的 PyTorch 容器镜像创建 Vertex AI 模型。

d. 将 Vertex AI 模型部署到端点上。

只要你在之前启用了 Vertex API(参见步骤 2。如果你忘记了!),这一切应该都会顺利进行。对我而言,在创建端点之后,模型部署大约花费了 30 分钟。当它准备好提供服务时,你会在 端点 下看到类似的内容:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 10:Vertex AI 端点。

你现在准备好从你的 Stable Diffusion 模型中处理请求了!!此时,如果你想节省成本,可以停止笔记本实例 VM 并删除你创建的存储桶。请参见 dreambooth_diffusion.ipynb 最后部分的“清理”步骤。

🧪 测试你的端点

要向 Vertex AI 端点发送请求,你需要使用支持发送具有适当 HTTP 方法和请求参数的请求的 HTTP 客户端库或命令行工具。

我在我的笔记本电脑上进行了本地测试。为此,你需要下载并安装 Vertex AI Python SDK,然后创建并下载一个服务密钥用于认证。

如果你保留了你的笔记本实例虚拟机,你可以使用之前的服务帐户别名,或者只需创建一个新的服务帐户,其权限至少包括从端点获取预测

1. 转到 IAM & Admin -> 服务帐户。然后点击服务帐户别名右侧操作下的三个点,然后点击 管理密钥

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 11:GCP IAM & Admin,服务帐户,管理服务帐户的密钥。

然后点击添加密钥 -> 创建新密钥,并按推荐的 .JSON 格式下载:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 12:GCP IAM & Admin,服务帐户,为服务帐户创建密钥。

⚠️ 请记住,服务帐户密钥文件授予与你的 GCP 项目中服务帐户相同的权限。 始终非常小心处理该文件,删除它时不要再使用,并且永远不要将其上传到 Github 仓库,我见过太多这种情况!

2. 在终端窗口中,或在你运行代码以测试端点的地方,你需要设置一个变量以指向你下载的服务密钥。这将对端点请求进行认证:

$ export GOOGLE_APPLICATION_CREDENTIALS=”path/to/your.json file

3. 现在只需用你的项目名称、区域和端点 ID 修改这个 Python 代码片段。它将向你的端点传递一个提示,并将响应(即由 Stable Diffusion 生成的图像)存储在 JPEG 文件中:

import base64
import logging
from io import BytesIO
from google.cloud import aiplatform as aip

PROJECT_NAME = “YOUR-PROJECT-ID”
REGION = “us-central1”
ENDPOINT_ID = “YOUR-ENDPOINT-ID”

aip.init(project=PROJECT_NAME, location=REGION)
endpoint = aip.Endpoint(endpoint_name=ENDPOINT_ID)
text_input = “””Polar bear on an iceberg”””

# Invoke the Vertex AI endpoint 
def query_endpoint(endpoint, text_input):
  payload = {“prompt”: text_input}
  response = endpoint.predict(instances=[payload])
  return response

response = query_endpoint(endpoint, text_input)

with open(“generated_imgage.jpg”, “wb”) as g:
    g.write(base64.b64decode(response.predictions[0]))

如果你安装了 Vertex AI Python SDK、进行了 GCP 认证,并且端点处于活动状态,几秒钟后你应该会看到生成的图像文件出现在你的文件系统中!

🎁 3. 使用 FLASK 创建一个简单的 RESTful API

此时,你可以很容易地将调用生成 AI 模型端点的代码集成到使用 Vertex AI Python SDK 的现有应用程序中。

但我确实承诺了一个 API,所以我将在这一部分的最后部分详细介绍它。

前提条件:

  • 安装 FlaskPillow Python 库。

  • 下载 Postman 并安装,它是免费的!我们将用它来模拟对我们的 API 的调用。

1. 与前一节一样,确保你有一个环境变量,指向你计划运行代码的地方,这个变量指向你的服务帐户密钥。

2. 这是你需要的代码,用于使用 Flask Web 应用程序框架创建一个简单的 RESTful API。请参见代码中的注释,了解正在发生的事情。我们基本上使用了之前的代码来查询端点,并将其封装在 API 调用中:

import base64
from google.cloud import aiplatform as aip
from flask import Flask, jsonify, request, send_file
from PIL import Image
from io import BytesIO

app = Flask(__name__)

@app.route(/predict’, methods=[‘POST’])
def predict():
  PROJECT_NAME = ‘YOUR-PROJECT-ID’
  REGION = ‘us-central1’
  ENDPOINT_ID = ‘YOUR-ENDPOINT-ID’

  # Get the input data from the HTTP request
  input_data = request.get_json()

  # Extract the text parameter from the input data
  prompt = input_data.get(‘prompt’, ‘’)

  aip.init(project=PROJECT_NAME, location=REGION)
  endpoint = aip.Endpoint(endpoint_name=ENDPOINT_ID)
  text_input = prompt

  # Invoke the Vertex AI
  payload = {“prompt”: text_input}
  response = endpoint.predict(instances=[payload])

  # Decode the image data from base64 format
  image_data = response.predictions[0]
  image_bytes = base64.b64decode(image_data)

  # Create a PIL Image object from the decoded image data
  image = Image.open(BytesIO(image_bytes))

  # Save the image to a BytesIO buffer
  buffer = BytesIO()
  image.save(buffer, format=’JPEG’)
  buffer.seek(0)

  # Return the image file in the response
  return send_file(buffer, mimetype=’image/jpeg’)

if __name__ == ‘__main__’:
    app.run(debug=True)

3. 运行生成的文件,你应该会有一个托管你 API 的 Flask 服务器,准备接收请求:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 13:MacOS 终端,本地运行 Flask 服务器。

4. 接下来我们将启动 Postman,进入 文件 -> 导入,粘贴以下 cURL 命令(如果你的服务器地址配置与默认不同,你可能需要修改它):

curl -X POST \
-H “Content-Type: application/json” \
-d ‘{“prompt”: “Astronauts in the ocean”}’ \
http://127.0.0.1:5000/predict \
-- output generated_image.jpg

我们可以从终端运行这个,API 的响应会被保存为本地文件系统中的 JPEG 图像。但为了模拟一个应用程序,而且因为我没有时间编写 Discord 机器人或 HTML 前端… 😅

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 14:Postman,API 请求示例,Stable Diffusion 模型生成的图像响应。

如果一切顺利,你将得到一个漂亮生成的图像,通过你自己的 API 提供服务。你可以在“Body”标签下修改提示,创造力是唯一的限制!

🗑️ 实验完成了吗? 你可以返回到 Vertex AI -> 端点,选择端点,撤销模型部署,返回上一层删除它(提示:每行末尾的三点菜单)。仔细检查你是否关闭或删除了与这篇文章相关的任何其他内容,以避免消耗资源。你也可以撤销你创建的服务账户密钥,以确保安全。

⚠️ 最后警告 — 这显然不是生产就绪的代码。如果你打算让你的 API 对外公开,还有很多工作需要做,比如认证等,这些超出了本文的范围!相对而言,我们的端点每次请求的成本(或计算时间)也比 Midjourney 或 DALL-E 2 高得多,因此可能不适合作为服务上线。

🏆 总结和结束思考

我在写这篇文章时非常开心,并且学到了很多关于生成式 AI、Stable Diffusion 以及将其打包成类似今天流行的消费者服务的内容。我向那些站在这项技术最前沿的开发团队致敬,这确实是一个令人兴奋的领域!

💸 成本

我有点担心运行大型 GPU 附加实例。如果你是长期使用公共云的用户,你会知道没有适当的检查和制衡机制,很容易超支。

在写完这篇文章的最后,我查看了 GCP 控制台中的账单部分,看看花了多少钱…

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 15:GCP 账单,报告。

在使用 GCP 的两天里,专门为这篇文章创建的项目中,我花费了大约 40 美元。你可以在上面的截图中看到详细分解。

大部分费用与推理相关(参见 Vertex AI 服务条目)。这是使我的端点能够使用我部署的 Stable Diffusion 模型生成图像的“计算小时”费用。

端点使用了“Vertex AI:在线/批处理预测 NVIDIA Tesla P100 GPU 运行于美洲的 AI 平台 SKU”,还包括运行预定义实例的一些其他费用。我们本可以通过选择更便宜的 GPU 家族来减少开支,但这将导致生成图像的时间延长。

与商业上可用的解决方案相比,Midjourney 的 基础计划 每月订阅费用为 10 美元。这个计划包括每月 3.3 小时的快速 GPU 计算,超出部分费用为每小时 4 美元。在我们的案例中,我让端点运行了大约 24 小时,花费了 30 美元,没有并发或作业等待的限制。再次说明,与 Midjourney 这样的完全托管服务相比,工程师们始终在改进模型,而选择在可扩展的云基础设施上迭代和部署自己的模型之间存在权衡。

这里需要注意的重要一点是,你支付的不是每次预测请求的费用,而是你的端点运行的时间以及它所运行的实例的大小/类型。

测试模型和通过用户管理的笔记本实际部署端点的费用不到 10 美元(参见计算引擎和笔记本条目)。

所以,本文简要概述了生成式人工智能,特别是使用 Stable Diffusion 技术创建艺术的应用。接着,我们深入探讨了一些代码示例,以展示使用 Stable Diffusion 模型生成艺术的步骤,以及如何将其部署到端点并通过 API 使用。希望你喜欢这篇文章,我们下次见! 👋🏼

本文中的数据由作者创建,除非另有说明。

📇 参考资料

[1] 机器学习,历史及其与其他领域的关系: en.wikipedia.org/wiki/Machine_learning#History_and_relationships_to_other_fields

[2] 什么是生成式人工智能?你需要知道的一切,George Lawton: www.techtarget.com/searchenterpriseai/definition/generative-AI

[3] 生成式人工智能对企业的影响——人工智能的新前沿,Deloitte LLP: www2.deloitte.com/us/en/pages/consulting/articles/generative-artificial-intelligence.html

[4] Midjourney 文档,版本: docs.midjourney.com/docs/model-versions#:~:text=Current%20Model,places%2C%20objects%2C%20and%20more

📚 有用资源

在 SageMaker Studio 中创建你自己的大语言模型实验室

原文:towardsdatascience.com/create-your-own-large-language-model-playground-in-sagemaker-studio-1be5846c5089?source=collection_archive---------8-----------------------#2023-03-20

现在你可以在一个地方部署和实验大语言模型

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Heiko Hotz

·

关注 发表在 Towards Data Science ·4 min read·2023 年 3 月 20 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供 — 使用 Midjourney 创建

这是什么内容?

通过 REST 端点利用大语言模型(LLMs)具有众多优点,但通过 API 调用进行实验可能会很麻烦。以下我们将看到如何与已部署到 Amazon SageMaker 端点的模型进行交互。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

为了简化这个过程,开发一个允许与已部署模型无缝互动的游乐场应用将是有利的。在本教程中,我们将通过使用 Amazon SageMaker(SM)Studio 作为我们的全功能 IDE,并将 Flan-T5-XXL 模型部署到 SageMaker 端点,随后创建一个基于Streamlit的游乐场应用,直接在 Studio 中访问。

本教程的所有代码都可以在这个GitHub 仓库中找到。

为什么这很重要?

评估和对比不同的 LLM 对组织来说至关重要,以确定最适合其独特需求的模型,并快速进行实验。一个游乐场应用提供了最便捷、快速和简单的方法,让利益相关者(无论是技术人员还是非技术人员)可以实验已部署的模型。

此外,利用游乐场应用可以增强对比,并促进进一步的定制,例如加入反馈按钮和对模型输出进行排名。这些附加功能使用户能够提供反馈,提升模型的精确性和整体性能。实质上,游乐场应用提供了对模型优势和劣势的更深入理解,最终帮助做出明智的决定,以选择最适合应用的 LLM。

让我们开始吧!

部署 Flan-T5-XXL 模型

在我们可以设置游乐场之前,我们需要设置一个 REST API 来访问我们的模型。幸运的是,在 SageMaker 中这非常简单。类似于我们部署 Flan-UL2 模型时所做的那样,我们可以编写一个推理脚本,从Hugging Face Model Hub下载模型,并将其部署到 SageMaker 端点。这个端点随后为我们提供一个 REST API,我们可以在 AWS 账户内访问,而不必使用 API Gateway。

请注意,我们使用了 8 位加载模型的选项,这使我们能够将模型部署到单个 GPU(G5 实例)上。

一旦我们准备好推理脚本,就可以通过一个命令部署模型:

欲了解更多详细信息,请查看部署笔记本和我之前的关于部署 Flan-UL2 的博客文章

一旦端点启动并运行,我们就可以开始有趣的部分——设置一个游乐场应用以与模型互动。

游乐场应用

我们将使用 Streamlit 开发一个精简的游乐场应用。只需几行代码,它就能让我们创建一个文本框,并在用户友好的界面中展示各种生成参数。欢迎您修改应用,并展示一组不同的生成参数,以便更好地控制文本生成过程。

所有生成参数的列表可以在这里找到。

请注意,你需要在第 10 行指定终端名称,你可以从 SageMaker 控制台的部署笔记本中获取。

测试

现在是时候部署和测试我们的实验平台应用程序了。受TensorBoard 在 SM Studio 中的使用说明的启发,我们可以使用相同的机制在 SM Studio 中启动我们的 Streamlit 应用。

为此,我们可以在终端执行命令streamlit run flan-t5-playground.py --server.port 6006。之后,我们将能够通过https://<YOUR_STUDIO_ID>.studio.<YOUR_REGION>.sagemaker.aws/jupyter/default/proxy/6006/访问这个实验平台。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

结论

在本教程中,我们成功部署了一个前沿语言模型,并在单一环境 SageMaker Studio 中建立了一个实验平台。启动 LLM 实验的过程从未如此简单。希望你觉得这些信息有价值,如果你有任何问题或需要进一步的帮助,请随时联系我。

Heiko Hotz

👋 关注我在MediumLinkedIn上,阅读更多关于生成 AI、机器学习和自然语言处理的内容。

👥 如果你在伦敦,欢迎加入我们的NLP London Meetups

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

www.linkedin.com/in/heikohotz/

在 AWS 上快速创建你自己的稳定扩散 UI

原文:towardsdatascience.com/create-your-own-stable-diffusion-ui-on-aws-in-minutes-35480dfcde6a?source=collection_archive---------0-----------------------#2023-01-03

使用一个命令部署文本到图像的 web 应用

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Heiko Hotz

·

关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 1 月 3 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供 — 使用稳定扩散创建

这是什么内容?

稳定扩散(SD)在 2022 年迅速成为最受欢迎的文本生成图像(即“AI 艺术生成”)模型之一。成功的一个关键因素是它被作为开源软件发布。这促使一个充满活力的社区迅速建立工具,使 SD 对任何对其感兴趣的人更加可及,无论其技术知识如何。

这些工具之一是简单而强大的 Web 界面 stable-diffusion-webui by Automatic1111。它允许我们在无需编程的情况下使用 SD 的所有功能,并且它还是开源的,这意味着任何人都可以下载这个 Web UI 以及 SD 模型,并在任何他们想要的地方进行部署。然而,挑战在于 SD 仍然需要 GPU 功率来运行,否则我们必须等待几分钟才能生成一张图像。而且我们许多人并不拥有足够强大的 GPU 来运行该模型。

多亏了云计算,我们无需花费巨资购买 GPU,而是可以“租用”一个。因此,在本教程中,我们将部署 Automatic1111 Web UI 到配备足够强大 GPU 以运行稳定扩散的 AWS EC2 实例上。我们将通过一个命令使用 AWS CloudFormation 模板来设置所需的所有基础设施。

与往常一样,你可以在我的 GitHub 账户中找到本教程的 代码

为什么这很重要?

拥有稳定扩散模型甚至是 Automatic 的 Web UI 作为开源工具,是民主化先进 AI 工具访问的重要一步。但这还不够,因为运行这些模型所需的 GPU 对大多数消费者来说依然昂贵。运行这些 AI 模型所需的 GPU 价格很容易超过 $2,000。

本教程展示了如何以每小时仅 $0.53 的价格入门,这就是 AWS 上 g4dn.xlarge 实例的按需价格。它允许我们使用带有 16 GB VRAM 的 NVIDIA T4 GPU。这意味着我们可以运行应用程序几个小时以进行试用,并生成我们想要的图像,然后关闭 EC2 实例,不需要支付超过实际使用时间的费用。而且这一切只需点击一个按钮,无需编程或 Linux 经验,这得益于 AWS CloudFormation 模板。

在开始之前需要说明一点:由于我在 AWS 工作,我显然对 AWS 有偏见。但我希望你从本教程中获得的核心信息是,通过云计算,先进的 AI 比以往任何时候都更具可及性和实惠性,无论你最终选择哪个服务提供商。

先决条件

跟随本教程,我们需要一个 AWS 账户,这几乎是唯一的先决条件。在本地机器上安装 AWS 命令行界面 (CLI) 会使事情变得更简单,但我也会演示如何在没有 CLI 的情况下仅使用 AWS 控制台来跟随教程。

快速入门指南

如引言中提到的,我们将使用 CloudFormation (CF) 模板通过一个命令来设置所有内容。启动模板后,应用程序需要 15–20 分钟才会准备好,所以这里有个提示:现在就启动模板,然后返回到这篇博客文章中,深入了解实际发生的背景情况 😉

在 AWS CLI 中启动应用程序

复制 CF 模板(或者,克隆整个 仓库)并在你的本地机器上运行下面的命令。这将会在你的 AWS 账户中创建一个名为“sd-webui-stack”的 CF 堆栈。

在 AWS 控制台中启动应用程序

在 AWS 控制台中,导航到 CloudFormation 部分,选择“创建堆栈 -> 使用新资源”:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

在接下来的对话框中,选择“模板已准备好”和“上传模板文件”:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

从仓库中选择 CF 模板文件进行上传,命名堆栈为“sd-webui-stack”,在接下来的对话框中保持默认设置,然后在最后一个对话框中点击“提交”。这将会在你的 AWS 账户中创建一个包含所有所需资源的堆栈。

一探究竟

现在我们已经启动了 CF 模板,可以揭开帷幕,深入了解实际发生的背景情况。

CloudFormation 模板

首先让我们来看一下 CF 模板:

这个模板设置了一些应用程序所需的资源。首先,我们创建一个“安全组”来指定 EC2 实例上哪些端口将会开放。我们选择端口 22,因为我们希望能够通过 SSH 连接到实例,同时选择端口 7860,因为我们的应用将在该端口监听。

接下来我们设置将托管应用程序的 EC2 实例。在这个模板中,我选择了 Ubuntu Server 22.04 LTS 发行版(AMI ID ami-0574da719dca65348),仅仅因为这是我最熟悉的。你可以更改为其他发行版,但请注意你需要相应地修改设置脚本(详见下文)。我们选择了一个 g4dn.xlarge 实例,如上所述。我们还配置了 300 GB 的磁盘空间,以确保有足够的空间来托管多个不同的模型。最后,我们在 EC2 实例上运行设置脚本,我们将在下一节中讨论。

接下来我们创建一个弹性 IP 地址,并将其分配给我们的 EC2 实例。这使我们能够拥有一个永久的 IP 地址,因此即使我们关闭 EC2 实例并在以后重新启动,它的应用程序也将始终托管在同一个 IP 地址上。

设置脚本

如前所述,我们在 EC2 实例上运行一个 setup script,该脚本将执行一些命令以为 Web UI 设置一切。我们一步一步来看。

第一部分禁用 Ubuntu 安装包后的重启对话框,然后安装我们需要的一些包:

下一部分下载并安装 CUDA 驱动程序,以便我们可以访问机器的 GPU:

之后,我们需要安装 Git Large File Storage,因为我们将下载一个大约 5 GB 的 Stable Diffusion 模型:

现在 Git LFS 已安装,我们可以从 Hugging Face Model Hub 下载模型。请注意,我们启用了“skip-smudge”选项,这允许我们仅下载所需的特定文件。在本教程中,我们下载了 SD v2.1(512px 版本),并将其移动到 Web UI 期望模型所在的目录中。

请注意,你可以更改脚本以下载不同版本的 Stable Diffusion,例如版本 1.5。你还可以在后续阶段通过将模型放入模型目录中,向 UI 添加任意多的模型。

除了模型之外,我们还需要一个由 WebUI 读取的配置文件。我们从 Stable Diffusion Github repo 下载一个配置文件,并将其重命名为与模型名称匹配,然后也放入相同的目录中:

最后,我们将 WebUI 的所有权更改为用户 ubuntu,并以该用户身份启动服务器(因为用户 root 不允许启动应用程序):

测试 Web UI

在 15–20 分钟后,部署应该完成。我们可以通过运行以下命令获取 EC2 实例的 IP 地址:

我们也可以通过访问 AWS 控制台中的 EC2 面板来检索它:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供

一旦我们获取了 IP 地址,就可以通过在浏览器中导航到 :7860 打开应用程序(如果请求超时则说明安装尚未完成)。安装完成后,我们可以看到应用程序已经启动运行 🎉

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

关闭 EC2 实例并重新启动应用程序

即使每小时仅需 $0.53,我们显然也不希望在不使用时运行实例。我们可以在 AWS 控制台中停止实例,并在需要时重新启动它,而不会丢失任何已安装的应用程序。一旦我们重新启动了 EC2 实例,就可以通过 SSH 登录并使用以下命令重新启动应用程序:

删除所有内容

如果我们想要删除所有创建的资源(即安全组、EC2 实例、弹性 IP),我们可以使用以下命令删除 CF 堆栈(或在 AWS 控制台中删除它):

限制

我想强调的是,这个教程只是一个起点,适合任何想尝试通过 Web UI 使用 Stable Diffusion 的人。这个方法有几个限制,如果想在生产环境中使用这个应用程序,可能需要解决这些限制。特别是,我没有涉及任何安全问题(请注意该应用程序运行在 http 上)、扩展问题(如果该应用程序需要同时服务多个用户)以及其他许多方面。

如果我们想在生产环境中使用这个应用程序,可以使用AWS Well-Architected Framework作为起点。

结论

在本教程中,我们利用 CF 模板通过一个命令设置了一个用于 Stable Diffusion 的 Web UI。这使我们能够访问最先进的 AI 模型,而无需自己购买昂贵的硬件。

这个应用程序的待办事项列表中有很多项目,我将它们列在了这里。对这个仓库的任何贡献都非常欢迎☺️

Heiko Hotz

👋 关注我在MediumLinkedIn上的文章,了解更多关于生成 AI、机器学习和自然语言处理的内容。

👥 如果你在伦敦,可以加入我们的NLP London Meetups

🤓 如果你对我如何可能帮助你在组织中采用 AI 和机器学习感兴趣,可以通过aiml.consulting与我联系。

几分钟内免费创建你自己的惊艳网站

原文:towardsdatascience.com/create-your-own-stunning-website-in-minutes-for-free-63f0f7c75bf

无需先前的网页开发经验

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Zolzaya Luvsandorj

·发表于 Towards Data Science ·阅读时间 8 分钟·2023 年 7 月 17 日

能够构建简单的网站带来了许多好处。也许你可以将简历发布到网站上以脱颖而出,或者创建自己的博客网站。可能性是无限的。有一种简单、快速且最重要的是免费的方式来托管静态网站,而无需了解之前的网页经验,只需利用预构建的主题。在这篇文章中,我将展示如何做到这一点。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:Brad Neathery 通过 Unsplash

📍 1. 概述

构建简单、快速且免费的静态网站有两个关键要素:GitHub Pages 用于免费托管静态网站,和 Start Bootstrap’s 免费惊艳主题,可以在几分钟内创建美丽的网站。你只需要一些 git 和 GitHub 的经验以及一个 GitHub 帐户。如果你需要复习 git 的基础知识,这个教程 可能对你有用。

📌 1.1. GitHub Pages

GitHub Pages 允许将公共 GitHub 仓库免费发布到网站上。根据 GitHub 仓库的命名方式,使用 GitHub Pages 托管的网站可以分为两种类型:

  • 用户网站: 每个 GitHub 用户可以拥有一个用户网站。GitHub 仓库应命名为 <username>.github.io,并可以通过 https://<username>.github.io/ 进行访问。

  • 项目网站: 每个 GitHub 用户可以拥有多个项目网站。GitHub 仓库可以命名为除 <username>.github.io 之外的任何名称,并且可以通过 https://<username>.github.io/<repository>/ 进行访问。

在这篇文章中,我们将一起构建一个项目网站作为实际示例。由于用户网站和项目网站之间唯一的区别是代码库名称和网页 URL,一旦你知道如何构建项目网站,构建用户网站就非常简单:你需要做的只是使用正确的代码库名称。

📌 1.2. Start Bootstrap

Start Bootstrap 提供免费的开源 Bootstrap 代码,用于美丽的示例网站。即使没有网页开发经验,我们也可以通过利用 Start Bootstrap 免费主题 背后的源代码,在几分钟内创建一个炫目的网站。

使用预建主题也意味着我们的选择将被限制在这些主题内。虽然我们可以通过这些主题进行快速建站并进行简单的自定义,但值得一提的是,越多的自定义需求就需要更多的网页开发经验和时间。在我们的示例中,我们将选择这些预建主题中的一个并稍作自定义。像用你的内容替换示例内容这样的简单变化,就像在任何软件中替换文本一样简单。

现在,让我们动手实践,边做边学吧!

📍2. 步骤指南

这里是建立网站的三个步骤:

1️⃣ 选择一个免费的 Start Bootstrap 主题,

2️⃣ 自定义它

3️⃣ 推送到 GitHub

听起来很简单?也许吧,因为确实如此。我们现在将详细了解每个步骤。

📌 2.1. 步骤 1:选择一个免费的 Start Bootstrap 主题并克隆代码

首先,我们需要从 Start Bootstrap 选择一个免费的主题。要查看它们:

◾️ 选择顶部面板上的主题,然后

◾️ 选择浏览所有主题或挑选一个所需的主题类别(例如:作品集与简历、博客),然后

◾️ 取消勾选价格下的Pro选项,只查看免费模板。

如果你点击任何主题的启动实时预览,你将看到样本网站的外观。一旦你找到想要使用的主题,点击查看源代码以查看 GitHub 上的代码库。代码库中有几个分支。我们需要克隆gh-pages(用于 GitHub Pages)分支上的代码。

让我们一起完成这个任务。我们将从挑选一个 清爽博客 主题开始,它位于博客类别下。我们将在终端中输入以下代码来克隆 代码库 的 GitHub Pages 分支:

# Clone gh-pages branch
git clone --branch gh-pages https://github.com/StartBootstrap/startbootstrap-clean-blog.git
# Go into the newly cloned directory
cd startbootstrap-clean-blog

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由作者提供 | 使用的快捷键:Tab 自动完成

如果你愿意,你可以将克隆的仓库重命名为你喜欢的名字。然而,我们在示例中不会进行重命名。

📌 2.2. 第 2 步:自定义代码

现在,我们将对代码进行以下三个小示例更改。

◾️ 更改首页上的图像

◾️ 替换关于页面中的文字

◾️ 从联系页面中删除部分内容

我们会确保每次更改都单独提交,以便提交历史记录清晰。

在自定义代码时,我们将熟悉仓库的文件夹结构。这个结构对于一般的网站非常常见。

📌 2.2.1. 更改首页上的图像 网站使用的图像都保存在assets/img文件夹中。

startbootstrap-clean-blog
└───assets
└───└───img
│   │   │   about-bg.jpg
│   │   │   contact-bg.jpg
│   │   │   home-bg.jpg
│   │   │   post-bg.jpg
│   │   │   post-sample-image.jpg
│   *

如果我们打开home-bg.jpg,我们会发现它是首页顶部使用的图像。我们也可以从它直观的名字猜测,它是页面的背景图像。现在,我们将从 Unsplash 这个提供免费下载美丽高分辨率图像的网站中找一个替代图像。我们选择了 这张图像,它是通过关键词‘planet’找到的。我们将这张新图像命名为home-bg.jpg并用它替换旧图像。完成后,以下是 提交 更改的 git 命令:

# Check status
git status
# Stage new image
git add assets/img/home-bg.jpg
# Check status
git status
# Commit staged change
git commit -m "Change background image on home page"
# Check status
git status

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像 | 使用的快捷键:Ctrl + L 清屏

2.2.2. 替换关于页面中的文字 现在,让我们练习用自己的内容替换一些内容。如果我们查看仓库结构,会看到以下 HTML 文件。

startbootstrap-clean-blog
│   about.html
│   contact.html
│   index.html
│   post.html
└───*

如果你双击about.html,它会在默认浏览器中打开。我们可以看到‘about’网页。现在,右键单击文件,用文本编辑器打开它。这是页面内容的源代码。你可以很容易地在源代码中找到网页上显示的文字。尝试更改源代码中看到的一些文字并保存,然后刷新浏览器中显示的文件。你将看到更改。这是感知你所做更改对网页影响的好方法,在提交之前你可以随意尝试文件,因为你总是可以用 Git 命令如git restoregit checkout恢复到最后一次提交的状态。

现在,我们将把“这是我做的事情”替换为“终身成长心态”,并将正文中的三个段落(第 55 至 57 行)替换为“你好,我是 Zolzaya。我喜欢数据科学。”类似地,你可以在任何页面上用自己的内容替换现有内容。然后,我们将 提交 更改:

# Check status
git status
# See changes in about.html
git diff about.html
# Stage new content on about page
git add about.html
# Check status
git status
# Commit staged change
git commit -m "Update content on about page"
# Check status
git status

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图像 | 使用的快捷键:Q 退出 less

2.2.3. 从联系页面中删除部分

如果您查看联系页面,您会看到有一个需要填写的表单。假设我们要删除这个表单。与之前一样,我们将同时在浏览器和文本编辑器中打开 HTML 文件。

HTML 文档由 HTML 标签组成。每个 HTML 标签通常都有一个开标签:<tag>和闭标签:</tag>,内容放在这些标签之间。因此,在删除这些标签时,确保删除整个标签部分。例如,如果您想删除以<div>…开头的部分,请确保将内容一直删除到相应的结束</div>标签。

您将看到表单位于第 56–107 行。我们将删除表单,然后将其上方的段落替换为“想要联系我?请随时通过 test@test.com 发邮件给我。”然后,我们将提交更改:

# Check status
git status
# See changes in about.html
git diff contact.html
# Stage new content on contact page
git add contact.html
# Check status
git status
# Commit staged change
git commit -m "Remove form from contact page"
# Check status
git status

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

我们已经完成了示例自定义,因此我们准备将更改推送到 GitHub。

📌 2.3. 第三步:推送到 GitHub,如有必要,请在 GitHub 中进行配置

让我们在 GitHub 上创建一个与您的本地仓库名称匹配的新公共仓库。在我们的示例中,它将是[startbootstrap-clean-blog](https://github.com/zluvsand/startbootstrap-clean-blog)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

仓库必须是公开的才能免费托管。

现在,我们将更新远程仓库的 URL,以便它指向新创建的仓库。然后我们将代码推送到远程仓库:

# Check remote
git remote -v
# Change remote
git remote set-url origin https://github.com/zluvsand/startbootstrap-clean-blog.git
# Check remote
git remote -v
# Push code to new remote
git push

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片 | 使用的快捷键:向上箭头访问之前的命令

现在,打开 GitHub 上的仓库。转到设置选项卡,然后从左侧面板中选择Pages

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片

您的网站将在推送代码后几分钟内从gh-pages分支中的代码自动构建。如果没有发生这种情况,或者您可能将分支命名为其他名称,则可以在 GitHub Pages 部分的分支子部分中指定。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片 | 突出显示的分支子部分

我们的示例仓库可以在这里找到,网站托管在这里

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片 | 来源:zluvsand.github.io/startbootstrap-clean-blog/

完成了,这就是全部内容!我们刚刚使用 GitHub Pages 和 Startbootstrap 主题一起构建了一个网站!显然,关于网页开发还有很多可以学习的内容。了解一些 HTML 和 CSS 会非常有用。如果你想学习 HTML 和 CSS 的基础知识,这个是 Udacity 提供的一个优秀的免费课程。如果你希望深入了解 GitHub Pages,可以查看GitHub Pages 文档

感谢阅读我的文章。希望你获得了有关如何利用免费资源构建网站的有用知识。如果你感兴趣,以下是一些你可能也喜欢的帖子链接:

◼️ 数据科学中的 Git 入门

◼️ 通过这些技巧丰富你的 GitHub 个人资料

再见啦 🏃💨

使用 LLaVA 创建你的视觉聊天助手

原文:towardsdatascience.com/create-your-vision-chat-assistant-with-llava-610b02c3283e?source=collection_archive---------3-----------------------#2023-11-11

开始使用开源的 LLaVA 模型来创建多模态对话模型。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Gabriele Sgroi, PhD

·

关注 发布在Towards Data Science ·17 分钟阅读·2023 年 11 月 11 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Izabela KrausUnsplash上拍摄。

简介

大语言模型已经证明自己是一项革命性的技术。利用其能力的众多应用已经被开发出来,预计还会有更多的应用出现。大语言模型的一个最有趣的应用是将其作为智能助手,能够在各种任务中帮助用户。经过指令调优和从人类反馈中学习的聊天模型显示出很有前景的能力,能够遵循人类指令并完成指定任务。然而,它们在仅语言任务的适用性上存在局限。

多模态对话模型旨在释放大语言模型的潜力,以解决需要将自然语言与其他模态结合的问题。特别是,自从将视觉能力引入 GPT-4V 以来,视觉语言模型受到了越来越多的关注。赋予 GPT-4 的自然语言能力以图像理解,已导致一个强大的聊天助手,可以帮助用户处理需要视觉和语言理解的任务。尽管 GPT-4V 的视觉能力令人印象深刻,但闭源模型限制了对这一令人惊叹的技术进行研究和实验的潜力。幸运的是,一些开源模型出现了,将视觉语言模型的力量以易于访问和透明的方式带给了社区。这些模型还延续了对计算和内存效率的关注趋势,这一趋势在开源大语言模型中已经显现。这是一个重要的特性,因为它促进了这些模型的广泛应用。

在本教程中,我将介绍如何使用在Visual Instruction Tuning论文中介绍的 LLaVA(大语言与视觉助手)模型创建一个视觉聊天助手。我将首先简要介绍 LLaVA 模型及其改进,然后讨论使用官方代码库中提供的代码实现一个简单的视觉聊天助手。我还将展示一些我设计的示例,以展示该模型的能力和局限性。

LLaVA

LLaVA 模型首次在论文视觉指令调整中提出,随后在改进的基线与视觉指令调整中进一步改进(也称为 LLaVA-1.5)。其背后的理念是从图像中提取视觉嵌入,并像处理语言标记生成答案一样处理它们,通过将它们馈送给大型语言模型。直观地说,我们可以认为图像将用语言模型生成答案所需的“单词”。为了选择正确的“单词”,模型使用预训练的 CLIP 视觉编码器提取视觉嵌入,然后将其投影到语言模型的单词嵌入空间中。后者的操作是通过一个视觉语言连接器完成的,最初在第一篇论文视觉指令调整中选择为简单的线性层,后来在改进的基线与视觉指令中替换为更具表现力的多层感知器(MLP)。模型的架构如下所示。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

LLaVA 模型的架构。在 LLaVA 中,投影 W 是一个简单的线性层,而在 LLaVA-1.5 中是一个 MLP。图片来源于论文视觉指令调整

该方法的优势之一在于通过使用预训练的视觉编码器和预训练语言模型,只需学习轻量级模块——视觉语言连接器,从头开始。具体而言,LLaVA 的训练包括两个阶段:

  • 特征对齐的预训练:冻结预训练的视觉编码器和语言模型的权重,仅更新视觉语言连接器的权重。所有训练样本由文本-图像对打包成单回合对话。此阶段旨在训练视觉语言连接器将视觉编码器的嵌入与语言模型的文本嵌入对齐。

  • 使用视觉指令进行微调:在此阶段,仅冻结视觉编码器的权重,同时对视觉语言连接器和语言模型进行微调。模型在基于图像的指令跟随任务上进行微调。有趣的是,一些数据是通过使用仅包含语言的 GPT4 从图像的标题和实体边界框的坐标创建指令跟随样本。

视觉聊天机器人实现

使用官方存储库中提供的代码创建视觉聊天机器人非常简单。该存储库还提供了标准化的聊天模板,可用于解析正确格式的输入。遵循训练中使用的正确格式对生成的答案质量至关重要。确切的模板取决于使用的语言模型。使用预训练的 Vicuna 语言模型的 LLaVA-1.5 的模板如下:

A chat between a curious user and an artificial intelligence assistant. The 
assistant gives helpful, detailed, and polite answers to the user's questions. 

USER: <im_start><image><im_end> User's prompt

ASSISTANT: Assistant answer

USER: Another prompt

前几行是模型使用的一般系统提示。特殊标记<im_start>、和<im_end>用于指示将放置表示图像的嵌入的位置。

聊天机器人可以在一个简单的 Python 类中定义。

class LLaVAChatBot:
    def __init__(self,
                 model_path: str = 'liuhaotian/llava-v1.5-7b',
                 device_map: str = 'auto',
                 load_in_8_bit: bool = True,
                 **quant_kwargs) -> None:
        self.model = None
        self.tokenizer = None
        self.image_processor = None
        self.conv = None
        self.conv_img = None
        self.img_tensor = None
        self.roles = None
        self.stop_key = None
        self.load_models(model_path,
                         device_map=device_map,
                         load_in_8_bit=load_in_8_bit,
                         **quant_kwargs)

    def load_models(self, model_path: str,
                    device_map: str,
                    load_in_8_bit: bool,
                    **quant_kwargs) -> None:
        """Load the model, processor and tokenizer."""
        quant_cfg = BitsAndBytesConfig(**quant_kwargs)
        self.model = LlavaLlamaForCausalLM.from_pretrained(model_path,
                                                           low_cpu_mem_usage=True,
                                                           device_map=device_map,
                                                           load_in_8bit=load_in_8_bit,
                                                           quantization_config=quant_cfg)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path,
                                                       use_fast=False)
        vision_tower = self.model.get_vision_tower()
        vision_tower.load_model()
        vision_tower.to(device='cuda')
        self.image_processor = vision_tower.image_processor
        disable_torch_init()

    def setup_image(self, img_path: str) -> None:
        """Load and process the image."""
        if img_path.startswith('http') or img_path.startswith('https'):
            response = requests.get(img_path)
            self.conv_img = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            self.conv_img = Image.open(img_path).convert('RGB')
        self.img_tensor = self.image_processor.preprocess(self.conv_img,
                                                          return_tensors='pt'
                                                          )['pixel_values'].half().cuda()

    def generate_answer(self, **kwargs) -> str:
        """Generate an answer from the current conversation."""
        raw_prompt = self.conv.get_prompt()
        input_ids = tokenizer_image_token(raw_prompt,
                                          self.tokenizer,
                                          IMAGE_TOKEN_INDEX,
                                          return_tensors='pt').unsqueeze(0).cuda()
        stopping = KeywordsStoppingCriteria([self.stop_key],
                                            self.tokenizer,
                                            input_ids)
        with torch.inference_mode():
            output_ids = self.model.generate(input_ids,
                                             images=self.img_tensor,
                                             stopping_criteria=[stopping],
                                             **kwargs)
        outputs = self.tokenizer.decode(
            output_ids[0, input_ids.shape[1]:]
        ).strip()
        self.conv.messages[-1][-1] = outputs

        return outputs.rsplit('</s>', 1)[0]

    def get_conv_text(self) -> str:
        """Return full conversation text."""
        return self.conv.get_prompt()

    def start_new_chat(self,
                       img_path: str,
                       prompt: str,
                       do_sample=True,
                       temperature=0.2,
                       max_new_tokens=1024,
                       use_cache=True,
                       **kwargs) -> str:
        """Start a new chat with a new image."""
        conv_mode = "v1"
        self.setup_image(img_path)
        self.conv = conv_templates[conv_mode].copy()
        self.roles = self.conv.roles
        first_input = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN +
                       DEFAULT_IM_END_TOKEN + '\n' + prompt)  # f"{self.roles[0]}: {prompt}")
        self.conv.append_message(self.roles[0], first_input)
        self.conv.append_message(self.roles[1], None)
        if self.conv.sep_style == SeparatorStyle.TWO:
            self.stop_key = self.conv.sep2
        else:
            self.stop_key = self.conv.sep
        answer = self.generate_answer(do_sample=do_sample,
                                      temperature=temperature,
                                      max_new_tokens=max_new_tokens,
                                      use_cache=use_cache,
                                      **kwargs)
        return answer

    def continue_chat(self,
                      prompt: str,
                      do_sample=True,
                      temperature=0.2,
                      max_new_tokens=1024,
                      use_cache=True,
                      **kwargs) -> str:
        """Continue the existing chat."""
        if self.conv is None:
            raise RuntimeError("No existing conversation found. Start a new"
                               "conversation using the `start_new_chat` method.")
        self.conv.append_message(self.roles[0], prompt)
        self.conv.append_message(self.roles[1], None)
        answer = self.generate_answer(do_sample=do_sample,
                                      temperature=temperature,
                                      max_new_tokens=max_new_tokens,
                                      use_cache=use_cache,
                                      **kwargs)
        return answer

如果您熟悉transformers库,您将会认识到许多常见功能,并且执行的操作应该很容易理解。让我们快速浏览上面定义的 LLaVAChatBot 类的方法。

  • load_models: 该方法加载语言模型、分词器和图像处理器,使用 BitsAndBytes 库进行量化,以减少 GPU 内存需求。该代码阴影了 Hugging Face transformers 模型使用的 from_pretrained 方法。BitsAndBytes 允许将模型量化为 8 位或 4 位。

  • setup_image: 它从本地路径或 URL 加载图像,并使用图像处理器将其转换为张量。

  • generate_answer: 该方法返回模型关于所提供图像当前对话的答案。LLaVa 模型的 generate 方法类似于 Hugging Face transformers 模型的 generate 方法。

  • get_conv_text: 该方法返回到目前为止对话的原始文本。

  • start_new_chat: 这是聊天机器人的两个主要方法之一,用于与模型开始新的聊天。它创建一个新的对话,给定图像和初始提示传递给模型。它通过使用存储库中定义的模板设置对话,遵循前面部分讨论的格式。

  • continue_chat: 另一个主要方法,它继续关于图像的现有对话。

我在此 Colab 笔记本中提供了完整的代码,这样你就可以自己尝试。

示例

在本节中,我收集了在测试模型能力时获得的一些有趣示例。本文中的所有示例都是使用 8 位量化的模型 llava-v1.5–7b 创建的。

首先,让我们看看该模型描述和理解图像的能力。下面我要求模型描述一张描绘白老虎的图片。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Joshua J. Cotten提供,来自Unsplash

A chat between a curious user and an artificial intelligence assistant. The 
assistant gives helpful, detailed, and polite answers to the user's questions. 

USER: <im_start><image><im_end>
Describe the image in details. 

ASSISTANT: The image features a white tiger laying down on a rocky surface, 
possibly a stone wall or a rock. The tiger appears to be relaxed and 
comfortable in its environment. The tiger's position is such that it covers a 
significant portion of the image, from the left side to almost the entire right
side. The rocky surface provides a natural and rugged backdrop for the scene.</s></s>

USER: What is unique about this tiger? 

ASSISTANT: This tiger is unique because it is a white tiger, which is an 
extremely rare and visually striking variation of the species. White tigers are
born with a genetic mutation that causes their fur to be predominantly white, 
with the rest of their body being a mix of orange and black stripes. This 
distinctive appearance sets them apart from the more common orange and black 
striped tigers.</s></s>

模型似乎能够详细描述图像,并且能够推理图像的主题,正确评估白老虎是一种罕见的情况。

接下来,我测试了模型是否能够为用户提供执行各种任务的可操作指导。这是区分“字幕生成”模型和真正的视觉对话助手的关键点。例如,模型能否根据图像中存在的成分提供食谱?结果如下报告。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由ThermoPro提供,来自Unsplash

A chat between a curious user and an artificial intelligence assistant. The 
assistant gives helpful, detailed, and polite answers to the user's questions. 
USER: <im_start><image><im_end>
What can I cook with these ingredients? 

ASSISTANT: With the ingredients displayed on the table, you can prepare a 
delicious roast dinner. The meat, which is a large piece of beef, can be 
seasoned and cooked in the oven, while the vegetables, including carrots, 
potatoes, and onions, can be roasted alongside the meat. You can also prepare a
side dish with the potatoes, carrots, and onions, which can be served alongside
the roast dinner. The presence of a bowl of sauce suggests that you can use it 
as a dipping sauce or to enhance the flavors of the roast and vegetables. 
Overall, this combination of ingredients offers a hearty and flavorful meal 
that can be enjoyed with family and friends.</s></s>

USER: Could the recipe you suggested be included in a vegetarian diet? 

ASSISTANT: The recipe I suggested, which involves roasting a large piece of 
beef and vegetables, is not suitable for a vegetarian diet. The beef is the 
main protein source in the dish, and it is not suitable for vegetarians who 
avoid consuming meat. However, you can still create a vegetarian-friendly meal 
using the vegetables and sauce availableon the table. For example, you can 
prepare a vegetable roast with the carrots, potatoes, and onions, and serve it 
alongside a side of mashed potatoes or a vegetable-based sauce. This way, you 
can still enjoy the flavors and textures of the dish without including meat.</s></s>

USER: Can you suggest a vegetarian recipe with minimal modifications of the 
depicted ingredients? 

ASSISTANT: Certainly! Here's a vegetarian recipe that can be made using the 
ingredients available on the table:

Roasted Vegetable Medley with Sauce

Ingredients:

* Carrots
* Potatoes
* Onions
* Olive oil
* Salt
* Pepper
* Herbs (such as thyme, rosemary, or parsley)
* Sauce (optional)

Instructions:

1\. Preheat your oven to 400°F (200°C).
2\. Wash and peel the carrots, potatoes, and onions. Cut them into bite-sized 
pieces.
3\. In a large bowl, toss the vegetables with olive oil, salt, and pepper.
4\. Spread the vegetables evenly on a baking sheet lined with parchment paper.
5\. Roast the vegetables in the preheated oven for 20-25 minutes, or until they 
are tender and slightly browned.
6\. If desired, prepare a sauce by combining ingredients such as tomatoes, 
garlic, and herbs. You can use this sauce as a dipping sauce or to enhance the
flavors of the roasted vegetables.
7\. Serve the roasted vegetable medley alongside a side of mashed potatoes or a 
vegetable-based sauce.

This vegetarian recipe allows you to enjoy the flavors and textures of the 
roasted vegetables without including meat. The sauce can be easily customized 
to suit your taste preferences, and the dish can be served alongside a variety 
of side dishes, such as mashed potatoes or a vegetable-based sauce.</s></s>

模型似乎对图像中存在的成分有很好的理解,并且能够将单个词语连接到更抽象的概念(即肉类餐品不是素食)。

现在让我们看看模型如何用于获得建议。这是在头脑风暴时获取新想法的有用功能,并且可以为用户提供真正的价值。在下面的对话中,我请求模型提供关于改变咖啡馆外观的建议。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Michał Mancewicz提供,来自Unsplash

A chat between a curious user and an artificial intelligence assistant. The 
assistant gives helpful, detailed, and polite answers to the user's questions. 

USER: <im_start><image><im_end>
How can I renovate this café to make it look more vintage? 

ASSISTANT: To make the café look more vintage, you can consider the following 
changes:

1\. Update the lighting: Replace the current lighting with vintage-style lamps 
or lanterns, which will create a warm and cozy atmosphere.

2\. Add vintage decorations: Incorporate antique items, such as vintage clocks, 
artwork, or decorative plates, to create a nostalgic ambiance.

3\. Use vintage furniture: Replace the current stools with vintage-style chairs 
or barstools, which can add character and charm to the space.

4\. Incorporate greenery: Add more greenery, such as potted plants or vines, to 
create a lush and inviting atmosphere.

5\. Update the bar: Replace the current bar with a vintage-style bar, complete 
with a vintage-inspired backsplash and countertop.

6\. Use vintage-style glassware: Replace the current glassware with vintage-style
glasses, such as wine glasses, beer mugs, or cocktail shakers.
7\. Add vintage-style signage: Replace the current signage with vintage-style 
menu boards or chalkboards, which can add a touch of nostalgia to the space.

8\. Use vintage-style wallpaper or paint: Apply a vintage-style wallpaper or 
paint to the walls, which can create a unique and charming atmosphere.

By implementing these changes, the café will appear more vintage and inviting, 
creating a memorable experience for its patrons.</s></s>

建议看起来切合要点,尽管相当通用,识别出可以更改的当前元素,以赋予更复古的风格。

让我们继续实际的内容,可以帮助用户在日常活动中获得帮助。模型是否能够帮助撰写广告或列表?

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Eugene Chystiakov提供,来自Unsplash

A chat between a curious user and an artificial intelligence assistant. The 
assistant gives helpful, detailed, and polite answers to the user's questions. 

USER: <im_start><image><im_end>
Write an advertisment for this chair 

ASSISTANT: Introducing the modern and stylish black and wood chair, perfect for
any contemporary setting. This chair features a sleek black seat and a wooden 
base, creating a striking contrast against any background. Its minimalist 
design makes it a versatile addition to any room, whether it's used as a dining
chair, a side table, or a statement piece in a living space. The chair's sturdy
construction ensures it can withstand daily use, while its clean lines and 
sleek appearance add a touch of sophistication to any room. Upgrade your living
space with this stylish and functional chair today.</s></s>

USER: Write now an advertisement for the same chair but targeting cafes 

ASSISTANT: Attention café owners and design enthusiasts! Introducing the black
and wood chair, the perfect addition to your café's seating area. This stylish
and functional chair offers a modern touch to your establishment, with its 
sleek black seat and wooden base creating a striking contrast against any 
background. Its minimalist design makes it a versatile choice for any café 
setting, whether it's used as a dining chair, a side table, or a statement 
piece in a lounge area. The chair's sturdy construction ensures it can 
withstand daily use, while its clean lines and sleek appearance add a touch of 
sophistication to any room. Upgrade your café's ambiance with this stylish and 
functional chair today.</s></s>

模型能够准确指出椅子的特征,并且能够创建吸引人的广告,当要求时正确地定制答案以适应不同的目标。

作为最后的测试,让我们看看是否可以像大型语言模型那样使用已被证明对 LLaVA 有益的提示技术。以下是尝试零点启发的链式思维提示的示例。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Arnold Middelkoop提供,来自Unsplash

A chat between a curious user and an artificial intelligence assistant. The 
assistant gives helpful, detailed, and polite answers to the user's questions. 

USER: <im_start><image><im_end>
Would the fish depicted in the image float if thrown in water? Think step by 
step. 

ASSISTANT: The fish depicted in the image is a large white fish sculpture or 
statue, which is lying on the ground on a brick walkway near a body of water. 
If thrown into the water, the fish would not float because it is a solid, 
inanimate object. It would sink to the bottom of the water, as it is not 
buoyant due to its composition and size. The fish's position on the ground and 
its lack of buoyancy make it impossible for it to float in water.</s></s>

USER: What material could be used to make a fish sculpture that floats on 
water? 

ASSISTANT: To make a fish sculpture that floats on water, one could use a 
lightweight, buoyant material such as expanded polystyrene foam or polyethylene
foam. These materials are commonly used for making buoyant objects like life 
jackets, flotation devices, and fishing floats. By shaping the foam into the 
desired fish form, it would be able to float on the water's surface, 
simulating the behavior of a real fish.</s></s>

该模型能够解释为什么鱼不会浮起,正确地识别出它不是真正的鱼。它还遵循了在给出最终答案之前逐步思考的指示。作为比较,这是一个没有“逐步思考”提示的对话示例。

A chat between a curious user and an artificial intelligence assistant. The 
assistant gives helpful, detailed, and polite answers to the user's questions.

USER: <im_start><image><im_end>
Would the fish depicted in the image float if thrown in water? 

ASSISTANT: No, the fish depicted in the image would not float if thrown in 
water. It is a large, white, stone fish sculpture, which is not a real fish 
and does not have the buoyancy properties of a natural fish.</s></s>

虽然最终答案仍然是正确的,但在这种情况下给出的解释要少得多,并且答案是在句子开头给出的。由于 LLaVA 使用的生成语言模型是因果的,这意味着提供的答案不依赖于后续的解释。进行更广泛的实验将很有趣,以测试 Chain of Thought(CoT)是否通常能够改善视觉语言模型在复杂推理任务中的表现,类似于观察到的大型语言模型在Chain-of-Thought Prompting Elicits Reasoning in Large Language ModelsLarge Language Models are Zero-Shot Reasoners中所述的情况。

限制

尽管非常强大,LLaVA 也有其局限性。例如,该模型已经训练成每次只使用一张图片,因此无法处理需要与多张图片交互的更复杂对话。然而,值得注意的是,这种限制并非模型架构固有的,可以通过在训练期间包含多图片对话数据来修正。

在测试中,我注意到该模型(至少是 1.5 7b 模型的 8 位量化版本)在计算图像中的元素数量以及处理数值任务方面存在困难。以下是一个例子。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由John Matychuk拍摄,来自Unsplash

A chat between a curious user and an artificial intelligence assistant. The 
assistant gives helpful, detailed, and polite answers to the user's questions. 

USER: <im_start><image><im_end>
How many car are inside the parking lot? 

ASSISTANT: There are several cars parked inside the parking lot.</s></s>

USER: List the numbers of the parking spots occupied by cars 

ASSISTANT: The parking spots occupied by cars are 66, 67, 68, 69, 70, and 71.</s></s>

像所有大型语言模型一样,该模型也容易产生幻觉,并且可以通过适当的提示轻易地被欺骗而犯错。例如,如果提供了沙漠的图片,并问这张图片中的鱼是哪种物种,它会产生沙漠鱼的幻觉。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

照片由Wolfgang Hasselmann拍摄,来自Unsplash

A chat between a curious user and an artificial intelligence assistant. The 
assistant gives helpful, detailed, and polite answers to the user's questions. 

USER: <im_start><image><im_end>
Tell me the species of the fish in this image 

ASSISTANT: The fish in this image is a desert fish, which is a species adapted 
to living in desert environments.</s></s>

结论

LLaVA 在视觉语言理解方面显示出令人印象深刻的能力。这标志着多模式开源视觉语言模型的明显进步。LLaVA 最大的优势之一是它易于训练和微调。例如,LLaVA 1.5 13b 的完整训练仅使用了 1.2M 数据,在单个 8-A100 节点上大约耗时 1 天。这使其适合在特定领域进行微调,以获得专家助理,例如在LLaVA-Med: Training a Large Language-and-Vision Assistant for Biomedicine in One Day中所做的那样。

为聊天助手添加视觉能力扩展了此类模型的应用领域,使其革命性潜力能够应用于更复杂和更细致的任务。将图像特征视为语言标记也带来了使用所有先进提示技术的可能性,这些技术通常用于纯文本语言模型,并进一步扩展了这些技术。例如,可以通过检索与对话相关的文本和图像来扩展检索增强生成的能力。实际上,利用 CLIP 的共享图像文本嵌入空间,可以通过输入文本或图片来检索外部文档和外部图像!

扩展模型能力的另一个有趣方向可以在LLaVA-Interactive: An All-in-One Demo for Image Chat, Segmentation, Generation and Editing中找到。主要思路是结合视觉语言聊天模型、文本生成图像模型以及其他视觉模型(如图像分割模型)的各种能力,以获得一个能够处理多模态输入并生成多模态输出的助手。

总之,LLaVA 为开源多模态生成模型标志着一个重要的步骤,这些模型展现了令人印象深刻的能力,并引起了广泛关注。随着开源模型的更广泛采用,我相信我们很快将见证这些强大模型的新应用的迅速增长。

感谢阅读!如果你想亲自尝试代码,可以查看这个 Colab 笔记本

从 RGB 视频创建 3D 视频

原文:towardsdatascience.com/creating-3d-videos-from-rgb-videos-491a09fa1e79?source=collection_archive---------3-----------------------#2023-08-03

生成一致的深度图和点云视频的指南

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Berkan Zorlubas

·

关注 发表在 Towards Data Science ·8 min read·2023 年 8 月 3 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

作者提供的图片(视频帧的编辑版本,来自 库存镜头 提供者 Videvo,下载自 www.videvo.net

我一直对我们将数字记忆以 2D 格式存档这一事实感到不满——尽管照片和视频很清晰,却缺乏它们所捕捉的经历的深度和沉浸感。在机器学习模型足够强大到理解照片和视频的 3D 感的时代,这似乎是一个任意的限制。

来自图像或视频的 3D 数据不仅使我们能够更生动、互动地体验记忆,还提供了编辑和后处理的新可能性。试想能够轻松地从场景中移除对象、更换背景,甚至改变视角以从新的视点查看一个瞬间。深度感知处理也为机器学习算法提供了更丰富的上下文来理解和处理视觉数据。

在寻找生成一致深度视频的方法时,我发现了一篇研究论文,它建议了一种很好的方法。这种方法涉及使用整个输入视频训练两个神经网络:一个卷积神经网络(CNN)来预测深度,一个 MLP 来预测场景中的运动或“场景流”。这个流预测网络以特殊的方式应用,在不同时间段内重复应用。这使它能够识别场景中的小变化和大变化。小变化有助于确保 3D 中的运动从一个时刻到下一个时刻是平滑的,而大变化有助于确保整个视频在不同视角下是一致的。这样,我们可以创建既局部又全球准确的 3D 视频。

这篇论文的代码库是公开的,但处理任意视频的管道并没有完全解释,至少对我来说,如何使用所提议的管道处理任何视频仍然不清楚。在这篇博客文章中,我将尝试填补这个空白,并逐步介绍如何在你的视频上使用这个管道。

你可以查看我在GitHub页面上的代码版本,我将会参考这个版本。

第 1 步:从视频中提取帧

在管道中的第一步是从选择的视频中提取帧。我添加了一个脚本用于这个目的,你可以在scripts/preprocess/custom/extract_frames_from_video.py中找到。要运行代码,只需在终端中使用以下命令:

python extract_frames_from_video.py ^
  -- video_path = 'ENTER YOUR VIDEO PATH HERE' ^
  -- output_dir = '../../../datafiles/custom/JPEGImages/640p/custom/' ^
  -- resize_factor = 0.5

使用resize_factor参数,你可以对帧进行下采样或上采样。

我选择了这个视频进行测试。最初,它的分辨率是 1280x720,但为了加快后续步骤的处理速度,我将其缩小到 640x360,使用了 0.5 的resize_factor

第 2 步:在视频中分割前景对象

我们过程中的下一步需要对视频中的一个主要前景物体进行分割或隔离,这对估计相机在视频中的位置和角度至关重要。原因是?离相机较近的物体对姿态估计的影响比离相机较远的物体要大。举个例子,想象一个距离 1 米远的物体移动 10 厘米——这将导致图像发生较大的变化,可能是几十个像素。但如果同样的物体距离 10 米远且移动相同的距离,图像变化则不那么明显。因此,我们生成了一个‘遮罩’视频,以便关注对姿态估计相关的区域,简化我们的计算。

我偏好使用Mask-RCNN来分割帧。你也可以使用其他你喜欢的分割模型。对于我的视频,我决定对右侧人物进行分割,因为他在整个视频中都出现在画面中,并且看起来离相机足够近。

要生成遮罩视频,需要对你的视频进行一些特定的手动调整。由于我的视频中包含两个人物,我首先对这两个人物进行了遮罩分割。之后,我通过硬编码提取了右侧人物的遮罩。根据你选择的前景物体及其在视频中的位置,你的方法可能会有所不同。负责创建遮罩的脚本可以在./render_mask_video.py中找到。我指定遮罩选择过程的脚本部分如下:

 file_names = next(os.walk(IMAGE_DIR))[2]
    for index in tqdm(range(0, len(file_names))):
        image = skimage.io.imread(os.path.join(IMAGE_DIR, file_names[index]))
        # Run detection
        results = model.detect([image], verbose=0)
        r = results[0]
        # In the next for loop, I check if extracted frame is larger than 16000 pixels, 
        # and if it is located minimum at 250th pixel in horizontal axis.
        # If not, I checked the next mask with "person" mask.
        current_mask_selection = 0
        while(True):
            if current_mask_selection<10:
                if (np.where(r["masks"][:,:,current_mask_selection]*1 == 1)[1].min()<250 or 
                    np.sum(r["masks"][:,:,current_mask_selection]*1)<16000):
                    current_mask_selection = current_mask_selection+1
                    continue
                elif (np.sum(r["masks"][:,:,current_mask_selection]*1)>16000 and 
                      np.where(r["masks"][:,:,current_mask_selection]*1 == 1)[1].min()>250):
                    break
            else:
                break
        mask = 255*(r["masks"][:,:,current_mask_selection]*1)
        mask_img = Image.fromarray(mask)
        mask_img = mask_img.convert('RGB')
        mask_img.save(os.path.join(SAVE_DIR, f"frame{index:03}.png"))

原始视频和遮罩视频在以下动画中并排显示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(左) 由 Videvo 提供的库存视频,从www.videvo.net下载 | (右) 作者创建的遮罩视频

步骤 3: 估计相机姿态和内部参数

创建遮罩帧后,我们现在开始计算相机姿态和内部估计。为此,我们使用一个叫做Colmap的工具。它是一个多视图立体视觉工具,可以从多个图像创建网格,并估计相机的移动和内部参数。它既有图形用户界面,也有命令行界面。你可以从这个链接下载该工具。

启动工具后,你需要点击顶部栏上的“重建”(见下图),然后选择“自动重建”。在弹出窗口中,

  • 进入./datafiles/custom/triangulation到“工作空间文件夹”

  • 进入./datafiles/custom/JPEGImages/640p/custom到“图像文件夹”

  • 进入./datafiles/custom/JPEGImages/640p/custom到“图像文件夹”

  • 进入./datafiles/custom/Annotations/640p/custom到“遮罩文件夹”

  • 勾选“共享内部参数”选项

  • 点击“运行”。

计算可能需要一些时间,具体取决于你有多少图像以及图像的分辨率。计算完成后,点击“文件”下的“将模型导出为文本”并将输出文件保存在./datafiles/custom/triangulation中。这将创建两个文本文件和一个网格文件(.ply)。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Colmap 的说明 — 图片由作者提供

这一步还没有结束,我们需要处理 Colmap 的输出。我编写了一个脚本来自动化这一过程。只需在终端中运行以下命令:

python scripts/preprocess/custom/process_colmap_output.py

它将创建“custom.intrinsics.txt”、“custom.matrices.txt”和“custom.obj”。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Colmap 的输出文件 — 图片由作者提供

现在我们准备好进行训练的数据集生成。

步骤 4:为训练准备数据集

训练需要一个数据集,该数据集包括每帧的深度估计由MiDas提供,相应的跨帧光流估计和深度序列。创建这些的脚本在原始仓库中提供,我只是更改了其中的输入和输出目录。通过运行下面的命令,将创建所有所需的文件并放置在适当的目录中:

python scripts/preprocess/custom/generate_frame_midas.py  &
python scripts/preprocess/custom/generate_flows.py  &
python scripts/preprocess/custom/generate_sequence_midas.py 

在训练之前,请检查datafiles/custom_processed/frames_midas/customdatafiles/custom_processed/flow_pairs/customdatafiles/custom_processed/sequences_select_pairs_midas/custom中是否存在 .npz 和 .pt 文件。验证后,我们可以继续进行训练。

步骤 5:训练

训练部分很简单。要用自定义数据集训练神经网络,只需在终端中运行以下命令:

python train.py --net scene_flow_motion_field ^
 --dataset custom_sequence --track_id custom ^
 --log_time  --epoch_batches 2000 --epoch 10 ^
 --lr 1e-6 --html_logger --vali_batches 150  ^
 --batch_size 1 --optim adam --vis_batches_vali 1 ^
 --vis_every_vali 1 --vis_every_train 1 ^
 --vis_batches_train 1 --vis_at_start --gpu 0 ^
 --save_net 1 --workers 1 --one_way ^
 --loss_type l1 --l1_mul 0 --acc_mul 1 ^
 --disp_mul 1 --warm_sf 5 --scene_lr_mul 1000 ^
 --repeat 1 --flow_mul 1 --sf_mag_div 100 ^
 --time_dependent --gaps 1,2,4,6,8 --midas ^
 --use_disp --logdir 'logdir/' ^
 --suffix 'track_{track_id}' ^
 --force_overwrite

经过 10 次迭代的神经网络训练后,我观察到损失开始饱和,因此决定不再继续训练更多的迭代。以下是我的训练损失曲线图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

损失 vs. 迭代曲线 — 图片由作者提供

在训练过程中,所有检查点都保存在目录./logdir/nets/中。此外,每次迭代后,训练脚本会在目录./logdir/visualize中生成测试可视化。这些可视化可以特别有助于识别训练过程中可能出现的任何潜在问题,除了监控损失外。

步骤 6:使用训练好的模型创建每帧的深度图

使用最新的检查点,我们现在用test.py脚本生成每帧的深度图。只需在终端中运行以下命令:

python test.py --net scene_flow_motion_field ^
 --dataset custom_sequence --workers 1 ^
 --output_dir .\test_results\custom_sequence ^
 --epoch 10 --html_logger --batch_size 1 ^
 --gpu 0 --track_id custom --suffix custom ^
 --checkpoint_path .\logdir

这将为每一帧生成一个 .npz 文件(一个包含 RGB 帧、深度、相机姿态、流向下一张图像等的字典文件),以及每一帧的三个深度渲染(真实值、MiDaS 和训练网络的估计)。

步骤 7:创建点云视频

在最后一步,我们逐帧加载批处理的 .npz 文件,并利用深度和 RGB 信息创建彩色点云。我使用 open3d 库在 Python 中创建和渲染点云。这是一个强大的工具,你可以在 3D 空间中创建虚拟相机并用它们捕捉点云。你还可以编辑/操作点云;我应用了 open3d 内置的异常点移除功能来去除闪烁和噪声点。

虽然我不会详细讨论我使用 open3d 的具体细节,以保持本文的简洁性,但我已包含了脚本 render_pointcloud_video.py,该脚本应当易于理解。如果你有任何问题或需要进一步澄清,请随时问我。

这是我处理的视频的点云和深度图像的视频效果。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

(左) 视频素材由 Videvo 提供,下载自 www.videvo.net | (右) 作者制作的深度图视频 | (下) 作者制作的彩色点云视频

此动画的高分辨率版本已上传至 YouTube

好吧,深度图和点云很酷,但你可能想知道你可以用它们做什么。与传统的效果添加方法相比,深度感知效果可以非常强大。例如,深度感知处理可以创建各种电影效果,否则很难实现。通过视频的估计深度,你可以无缝地加入合成的相机对焦和虚焦,产生真实且一致的散景效果。

此外,深度感知技术提供了实现动态效果如“推镜变焦”的可能性。通过调整虚拟相机的位置和内参,这种效果可以生成惊艳的视觉序列。此外,深度感知的对象插入确保虚拟对象在视频中真实固定,保持整个场景中的一致位置。

深度图和点云的结合为引人入胜的叙事和富有创意的视觉效果开辟了无限可能,推动了电影制作人和艺术家的创意潜力达到新的高度。

点击本文的“发布”按钮后,我将挽起袖子开始制作这些效果。

祝你有美好的一天!

使用 Spark、Google Cloud Storage 和 Big Query 创建数据管道

原文:towardsdatascience.com/creating-a-data-pipeline-with-spark-google-cloud-storage-and-big-query-a72ede294f4c?source=collection_archive---------7-----------------------#2023-03-06

本地和云端协同工作以交付数据产品

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 João Pedro

·

关注 发布于 Towards Data Science · 10 分钟阅读 · 2023 年 3 月 6 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Toro Tseleng 提供,来源于 Unsplash

开发数据管道在某种程度上类似于玩乐高,你需要构思实现目标(数据需求),选择合适的零件(软件、工具、平台),然后将它们组装在一起。就像在乐高中一样,构建过程的复杂性取决于最终目标的复杂性。

从使用 Python 构建的简单 ETL 管道在两个数据库之间移动数据,到使用 Kafka 在各种云结构之间流式传输实时消息以服务多个最终应用程序,非常复杂的结构都是可能的。

但现实是,当前的数据场景更像那些昂贵的专业乐高套件,拥有各种解决特定需求的零件,并且新的零件在每个角落不断出现。你可能已经看过 Matt Turck 的 2021 年机器学习、人工智能和数据 (MAD) 领域。而糟糕的部分——说明书没有包含在内。

过去十年中,许多开源数据相关工具如 Spark、Hadoop 和 Kafka 已经被开发出来,更不用说 Python 库中所有可用的工具了。这些就是我喜欢在文章中覆盖的工具,它们是免费的,维护良好的 Docker 镜像,并且我可以使用 Docker 开发自包含的应用程序,任何人都可以在任何地方运行。

但是,随着数据领域的成熟,所有的箭头似乎都指向同一个方向——云计算。这一点毫不令人惊讶。专注于数据应用的公司如 Databricks、DBT 和 Snowflake 正在迅速流行,而传统玩家(AWS、Azure 和 GCP)也在大力投资他们的数据产品。

这就是今天文章的目标——我们将使用 Apache Spark、Google Cloud Storage 和 Google Big Query(使用免费层)来开发一个数据管道。

未经赞助。

工具

Spark 是一个面向处理极大量数据的通用分布式内存数据处理框架。我在许多其他文章中介绍过 Spark。

Google Cloud Storage (GCS) 是 Google 的对象存储。概念很简单:创建一个桶并将文件存储在其中。稍后使用它们的“路径”读取。文件夹是虚假的,对象是不可变的。

Google Big Query (GBQ) 是 Google 的云数据仓库解决方案。一个以 OLAP 为重点的数据库,具有无服务器 SQL 查询执行能力,能够处理大量数据。

数据

我们将构建一个数据管道来处理和存储来自巴西“高等教育”(字面翻译)普查的数据。该普查每年收集有关巴西高等教育机构(主要是大学)从不同角度的许多统计数据:机构的、社会和人口统计的、地理的等等。

我们将处理 课程报告,其中包含有关每个巴西高等教育课程(本科、研究生、博士等)的统计数据。这些数据公开可用 [CC BY-ND 3.0],以 CSV 文件形式提供(每年一个)。

实现

管道的想法很简单,将 CSV 文件下载到本地机器,转换为存储在 GCS 桶中的 Delta-Lake 表,对这个 delta 表进行所需的转换,并将结果保存到一个 Big Query 表中,以便其他下游任务可以轻松使用。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

提议的管道。作者提供的图像。

存储桶将作为原始文件存储,聚合所有报告到一个地方。BigQuery 表将存储我们准备好的数据,已经过滤、聚合,并且只包含有用的列。

如前所述,普查收集了来自所有高等教育机构的大量统计数据,包括但不限于大学。为了模拟“真实情况”,假设我们需要创建一个表格来回答关于每年进入大学的新生的各种社会/人口统计问题。

0. 设置环境

所有代码都可以在这个 GitHub 仓库中找到。

你需要在本地计算机上安装 docker 以创建 Spark 集群,并且需要 python 来下载文件。

docker-compose 文件:

version: '3'

services:
  spark:
    build: .
    environment:
      - SPARK_MODE=master
    ports:
      - '8080:8080'
      - '4040:4040'
    volumes:
      - ./data:/data
      - ./src:/src
  spark-worker:
    build: .
    environment:
      - SPARK_MODE=worker
      - SPARK_MASTER_URL=spark://spark:7077
      - SPARK_WORKER_MEMORY=4G
      - SPARK_EXECUTOR_MEMORY=4G
      - SPARK_WORKER_CORES=4
    volumes:
      - ./data:/data
      - ./src:/src 

spark Dockerfile:

FROM docker.io/bitnami/spark:3.3.1

COPY *.jar $SPARK_HOME/jars

RUN mkdir -p $SPARK_HOME/secrets
COPY ./src/credentials/gcp-credentials.json $SPARK_HOME/secrets/gcp-credentials.json
ENV GOOGLE_APPLICATION_CREDENTIALS=$SPARK_HOME/secrets/gcp-credentials.json

RUN pip install delta-spark

docker 镜像已经配置为从头自动创建一个新环境,因此我们可以更多地关注实现而不是配置。

当然,你需要创建一个 Google Cloud Platform 账户。尽管我们只会使用免费的配额,但仍然需要你的信用卡信息。GCP 表示除非你明确结束免费试用期,否则不会收费,但请小心。

创建账户后,按照以下步骤操作:

1. 访问 GCP 控制台并创建一个新项目。我将我的项目命名为“BigQueryFirstSteps”

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2. 在 API & Services 标签中授权 Google Cloud Storage 和 BigQuery 的 API。

3. 在 Google Cloud Storage 中创建一个名为censo-ensino-superior的新存储桶

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

4. 在 Google Big Query 中创建一个名为censo-ensino-superior的新数据集

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

5. 在 IAM & Administrator 标签页中的服务账户项目下创建一个新的服务账户,并分配适当的角色以读取、写入和创建 GCP 存储桶和 GBQ 表(我使用了BigQuery 管理员存储管理员角色)

6. 在此页面上,生成一个新的访问密钥(JSON)用于新创建的账户。密钥将下载到你的本地计算机。

返回本地环境,执行prepare_env.sh 脚本。

mkdir -p ./data/
mkdir -p ./src/credentials
chmod -R 777 ./src
chmod -R 777 ./data

wget https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-latest-hadoop2.jar

它创建了几个具有特定授权的文件夹(以便 spark 容器可以从中读取和写入)并下载了 GCS 连接器以供 spark 使用。

现在,将你的 JSON 凭证文件重命名为gcp-credentials.json并放置在 ./src/credentials 文件夹中(与 bucket_name.txt 文件一起)。

最后,启动容器:

docker compose up --build

1. 下载数据

只需运行脚本:

python download_files.py

CSV 文件将下载到 ./data 文件夹中。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

2. 将 CSV 转换为 GCS 中的 Delta Lake

首先需要做的是实例化一个 Spark 会话,并配置 Delta-Lake 依赖项。

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

from delta import configure_spark_with_delta_pip

MASTER_URI = "spark://spark:7077"

if __name__ == "__main__":
    # spark-submit --packages io.delta:delta-core_2.12:2.1.0 --master spark://spark:7077 insert_csv_into_delta_table_gcs.py

    builder = SparkSession.builder\
        .master(MASTER_URI)\
        .appName("Insert CSV Censo into Delta Table")\
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
        .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")\

    spark = configure_spark_with_delta_pip(builder).getOrCreate()

读取下载的 CSV 文件非常简单,只需指定选项并给出路径。

# Read the CSV file
df_cursos = (
    spark.read
    .format("csv")
    .option("header", "true")
    .option("delimiter", ";")
    .option("encoding", "ISO-8859-1")
    .option("inferSchema", "true")
    .load("/data/MICRODADOS_CADASTRO_CURSOS_*.CSV") # read all csv files
)

为了在 GCP 桶中写入更少的数据,我还仅选择了对最终表有用的列:普查年份;课程标识;知识领域;地点和类型;按性别、年龄和肤色统计的新生人数。

# Select Columns
df_cursos = df_cursos.select(
    [   
        # Year
        "NU_ANO_CENSO",
        # Course AND Institution
        "CO_IES",
        "NO_CURSO",
        "CO_CURSO",
        # Total of new students
        "QT_ING",
        # Age
        "QT_ING_0_17",
        # ...
        # Skin Color
        "QT_ING_BRANCA",
        "QT_ING_PRETA",
        # ...
        # Gender COLUMNS
        # Place COLUMNS
        # Area of Knowledge (CINE) COLUMNS
        # FIELDS OMITTED TO MAKE THIS CODE BLOCK SMALLER
    ]
)

# cast columns to the correct type
for col in df_cursos.columns:
    if col in ["NU_ANO_CENSO"] or col.startswith("QT_"):
        df_cursos = df_cursos.withColumn(col, df_cursos[col].cast(IntegerType()))
    elif col.startswith("IN_"):
        df_cursos = df_cursos.withColumn(col, df_cursos[col].cast(BooleanType()))
    else:
        df_cursos = df_cursos.withColumn(col, df_cursos[col].cast(StringType()))

将数据写入 GCP 桶就像写入文件系统一样,但我们需要用自己的语法指定桶路径:“gs://<bucket_name>/<filepath_to_be_create>”。

df_cursos.write\
        .format("delta")\
        .partitionBy(["NU_ANO_CENSO"])\
        .mode("overwrite")\
        .save("gs://censo-ensino-superior/cens_cursos")

上述代码在 censo-ensino-superior 桶内创建了一个名为 censo_cursos 的新 Delta 表。

我们不需要在代码中处理身份验证,因为凭据在 Docker 构建阶段已经正确配置。

要执行此脚本,请访问 Spark 容器的终端(主节点或工作节点)并执行:

# cd into the folder where the script is stored
spark-submit --packages io.delta:delta-core_2.12:2.1.0 --master spark://spark:7077 insert_csv_into_delta_table_gcs.py

一分钟左右,脚本将完成,数据将可用在你的 GCS 桶中。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

3. 从 GCS 处理 Delta 表到 GBQ

首先需要做的是实例化一个 Spark 会话,与之前所做的相同。

从桶中读取数据也遵循与写入相同的逻辑。

df_censo = (
      spark.read
          .format("delta")
          .load("gs://censo-ensino-superior/cens_cursos")
)

数据准备好后,我们可以像往常一样进行一些转换。

 df_censo = (
        df_censo
        # Bachelor and Licenciatura TP_GRAU_ACADEMICO = 4
        .filter( 
            (F.col('TP_GRAU_ACADEMICO') == "1")  
            | (F.col('TP_GRAU_ACADEMICO') == "4")
        )
        # Group by CO_CINE_AREA_DETALHADA, CO_UF (STATE) and NU_ANO_CENSO (YEAR)
        .groupBy(
            'CO_CINE_AREA_DETALHADA', 'CO_UF', 'NU_ANO_CENSO'
        )
        .agg(
            F.max('NO_CINE_AREA_DETALHADA').alias('NO_CINE_AREA_DETALHADA'),

            F.max('NO_CINE_AREA_ESPECIFICA').alias('NO_CINE_AREA_ESPECIFICA'),
            F.max('CO_CINE_AREA_ESPECIFICA').alias('CO_CINE_AREA_ESPECIFICA'),
            F.max('NO_CINE_AREA_GERAL').alias('NO_CINE_AREA_GERAL'),
            F.max('CO_CINE_AREA_GERAL').alias('CO_CINE_AREA_GERAL'),

            F.max('SG_UF').alias('SG_UF'),
            F.max('NO_REGIAO').alias('NO_REGIAO'),
            F.max('CO_REGIAO').alias('CO_REGIAO'),

            F.count('CO_CURSO').alias('QT_CO_CURSO'),
            F.sum('QT_CURSO').alias('QT_CURSO'),
            F.sum('QT_VG_TOTAL').alias('QT_VG_TOTAL'),
            F.sum('QT_ING').alias('QT_ING'),

            F.sum('QT_ING_0_17').alias('QT_ING_0_17'),
            F.sum('QT_ING_18_24').alias('QT_ING_18_24'),
            F.sum('QT_ING_25_29').alias('QT_ING_25_29'),
            F.sum('QT_ING_30_34').alias('QT_ING_30_34'),
            F.sum('QT_ING_35_39').alias('QT_ING_35_39'),
            F.sum('QT_ING_40_49').alias('QT_ING_40_49'),
            F.sum('QT_ING_50_59').alias('QT_ING_50_59'),
            F.sum('QT_ING_60_MAIS').alias('QT_ING_60_MAIS'),

            F.sum('QT_ING_BRANCA').alias('QT_ING_BRANCA'),
            F.sum('QT_ING_PRETA').alias('QT_ING_PRETA'),
            F.sum('QT_ING_PARDA').alias('QT_ING_PARDA'),
            F.sum('QT_ING_AMARELA').alias('QT_ING_AMARELA'),
            F.sum('QT_ING_INDIGENA').alias('QT_ING_INDIGENA'),
            F.sum('QT_ING_CORND').alias('QT_ING_CORND'),

            F.sum('QT_ING_FEM').alias('QT_ING_FEM'),
            F.sum('QT_ING_MASC').alias('QT_ING_MASC'),
        )
    )

上述查询首先筛选出仅包含学士和学位课程(使用它们的代码),并按年份、详细领域和州对结果进行分组,汇总每个类别的新生人数(“QT_ING_XYZ”)。

要将数据写入 BigQuery 表,我们需要使用格式 “bigquery” 以及一些选项。显然,我们需要传递要写入的表和数据库。由于这个表尚不存在,因此需要将选项“createDisposition”设置为“CREATE_IF_NEEDED”。

标准的 GBQ-Spark 连接器使用 GCS 桶作为一个中间缓冲区,用于从 GBQ 传输数据。因此,我们需要传递一个“temporaryGcsBucket”选项,并指定桶名称。为了简单起见,我使用了之前创建的相同桶。

df_censo.write\
    .format("bigquery")\
    .mode("overwrite")\
    .option("temporaryGcsBucket", "censo-ensino-superior")\
    .option("database", "censo_ensino_superior")\
    .option("table", "censo_ensino_superior.cursos_graduacao_e_licenciatura")\
    .option("createDisposition", "CREATE_IF_NEEDED")\
    .save()

同时,请注意这次写入是以 mode=“overwrite” 模式进行的,如果表已经存在,它将覆盖任何以前的数据。如果你只想添加新行,请使用“append”模式。

就这些了。

要运行这个任务,只需输入:

# cd into the folder where the script is stored
spark-submit --packages io.delta:delta-core_2.12:2.1.0,com.google.cloud.spark:spark-3.1-bigquery:0.28.0-preview aggregate_delta_gcs_to_gbq_table.py

表将会被创建并填充,让我们来看看:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

为了举例说明,让我们运行一个查询。

下面的查询计算了每个知识领域男性和女性的百分比。

SELECT
  CO_CINE_AREA_GERAL,
  NO_CINE_AREA_GERAL,
  SUM(QT_ING_MASC)/SUM(QT_ING_FEM + QT_ING_MASC) AS PERCENT_MASC,
  SUM(QT_ING_FEM)/SUM(QT_ING_FEM + QT_ING_MASC) AS PERCENT_FEM  
FROM 
  `censo_ensino_superior.cursos_graduacao_e_licenciatura`
GROUP BY
  NO_CINE_AREA_GERAL,
  CO_CINE_AREA_GERAL
ORDER BY
  PERCENT_MASC

结果:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

带有注释翻译的查询结果。图片由作者提供。

结论

在这篇文章中,我们学习了如何使用 Apache Spark(本地部署)作为数据处理工具,并将 Google Cloud Storage(用于原始文件存储)和 Google Big Query(用于为分析查询提供处理后的数据)作为存储解决方案来开发数据管道。

使用 Spark 与云中的数据进行交互并没有那么特别。更困难的部分是配置环境:找到合适的连接器,将其放置在正确的位置,并使用正确的格式。坦白说,我在学习如何正确配置 Docker 镜像和凭据时遇到了很多困难。但一旦掌握了这个过程,查询和操作数据就会像平常一样。

这也是我最喜欢 Apache Spark 的其中一个原因——它将处理逻辑与连接逻辑分开。例如,如果我们将 blob 存储解决方案从 GCS 更改为 Amazon 的 S3,我们需要做的就是用新的 AWS 凭据重新配置环境并更改读写命令。所有的查询/转换逻辑保持不变。

但除了“没有那么特别”之外,学习如何与云存储组件交互是一个极其重要的技能,我希望这篇文章能帮助你更好地理解这个过程。

和往常一样,我不是帖子中涉及的任何主题的专家,我强烈建议进一步阅读,见下方参考文献。

感谢阅读!😉

参考文献

所有代码都可以在 这个 GitHub 仓库中找到

使用的数据 —* Microdados do Censo da Educação Superior,[CC BY-ND 3.0],INEP-巴西政府

[1] Chambers, B., & Zaharia, M. (2018). Spark: The definitive guide: Big data processing made simple. “ O’Reilly Media, Inc.”.

[2] 什么是 BigQuery?(无日期)。Google Cloud链接

[3] Delta Lake 官方页面。(无日期)。Delta Lake。 delta.io/

[4] Databricks. (2020 年 9 月 15 日)。利用 Delta Lake 改善 Apache SparkTM [视频]。YouTube。

[5]使用 BigQuery 连接器与 Spark。(无日期)。Google Cloud。 链接

[6] Sohail, K. (2021 年 12 月 15 日)。使用本地 PySpark 和 Jupyter Notebooks 从 Google Cloud Storage Bucket 读取文件。Medium。 链接

创建荷兰语问答机器学习模型

原文:towardsdatascience.com/creating-a-dutch-question-answering-machine-learning-model-3b666a115be3?source=collection_archive---------3-----------------------#2023-01-29

自然语言处理教程

使用自然语言处理翻译创建新的数据集

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Erwin van Crasbeek

·

关注 发布于 Towards Data Science ·20 min 阅读·2023 年 1 月 29 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

荷兰语问答模型创建流程

自然语言处理模型目前是一个热门话题。谷歌发布的《Attention Is All You Need》[1] 推动了许多像 BERT、GPT-3 和 ChatGPT 这样的 Transformer 模型的发展,这些模型受到了全球的广泛关注。虽然许多语言模型是在英语或多语言上进行训练的,但针对特定语言的模型和数据集可能难以找到或质量堪忧。

NLP 有广泛的应用,包括但不限于翻译、信息提取、摘要和问答,而后者是我个人一直在从事的工作。作为应用人工智能的学生,我一直在研究问答 NLP 模型,并且发现很难找到有用的荷兰语数据集用于训练。为了解决这个问题,我开发了一个翻译解决方案,可以应用于各种 NLP 问题和几乎所有语言,这可能对其他学生有兴趣。我认为这对人工智能开发和研究社区也具有很大的价值。特别是对于像问答这样的特定任务,几乎没有荷兰语数据集。通过翻译一个大型且知名的数据集,我能够以相对较低的努力创建一个荷兰语问答模型。

如果你有兴趣了解更多关于我的过程、我面临的挑战以及此解决方案的潜在应用,请继续阅读。本文旨在为具有基本 NLP 背景的学生提供。然而,我还为那些尚未熟悉该领域或仅需复习的人士提供了复习材料和各种概念的介绍。

为了正确解释我使用翻译数据集的解决方案,我将本文分为两个主要部分:数据集的翻译和问答模型的训练。我撰写本文的方式旨在展示我在解决方案方面的进展,同时也作为一个逐步指南。文章包括以下章节:

  1. 关于 NLP 的复习和 NLP 的简要历史

  2. 问题、数据集和问答

  3. 翻译数据集

  4. 构建一个问答模型

  5. 已取得的成就与未取得的成就?

  6. 未来计划

  7. 来源

关于 NLP 的复习和 NLP 的简要历史

为了更好地理解这个解决方案的各个元素,我想从对 NLP 及其近期历史的复习开始。我们所知道的语言可以分为两组,形式语言和自然语言。形式语言指的是专门为特定任务如数学和编程设计的语言。自然语言或普通语言是指由人类自然发展和演变的语言,没有任何形式的预先规划。这可以表现为我们所知道的各种人类言语形式,甚至是手语[2]。

NLP 在其最广泛的形式上是将计算方法应用于自然语言。通过将基于规则的语言建模与人工智能模型相结合,我们已经能够使计算机以一种能够处理文本和语音形式的方式“理解”我们的自然语言[3]。这种理解的方式——如果它真的可以称为理解的话——仍然存在争议。然而,像 ChatGPT 这样的最新发展表明,我们人类确实常常觉得这些模型的输出让人感到它有自我意识,并且具有较高的理解水平[4]。

当然,这种理解并非凭空而来。NLP 有着广泛的历史,可以追溯到二战后的 1940 年代[5]。在这个时期,人们意识到了翻译的重要性,并希望创造一种能够自动完成翻译的机器。然而,这证明是相当具有挑战性的。大约在 1960 年左右,NLP 研究分为基于规则的和随机的两大类。基于规则的或符号化的主要涉及形式语言和语法生成。这个领域的许多语言学研究者和计算机科学家认为这是人工智能研究的开始。随机研究则更多关注统计学和文本间的模式识别等问题。

自那时起,NLP(自然语言处理)领域取得了许多进展,研究领域也不断扩展。然而,NLP 模型生成的实际文本一直相当有限,且缺乏许多现实世界的应用。直到 2000 年代初期,NLP 的发展才迎来了每隔几年便有显著突破的阶段,这才导致了我们现在的情况。

问题、数据集和问答

现在我已经简要回顾了 NLP 的背景,是时候介绍我一直在研究的实际问题了。简而言之,我的目标是训练一个荷兰语问答的机器学习模型。然而,由于缺乏合适的数据集,这变得相当困难,因此我通过翻译创建了自己的数据集。在本文中,我将逐步讲解数据集的创建和机器学习模型的训练,以便你可以跟随并复制整个解决方案,或选择对你来说重要的部分。

本文可以分为两个主要部分。第一个是荷兰语数据集的创建,第二个是问答机器学习模型的训练。在这一章中,我将提供一些背景信息,介绍我的解决方案并解释我的选择。

数据集

如果我们想找到一个有用的荷兰语数据集,那么了解训练一个问答模型所需的具体内容是很重要的。生成答案的主要有两种方法:第一种是抽取式,第二种是生成式。

· 抽取式问答模型被训练以从上下文(源文本)中提取答案[7]。较早的方法通过训练一个模型来输出答案在上下文中的起始和结束索引来实现这一点。然而,Transformer 的引入使这种方法已经过时。

· 生成式问答模型被训练以根据上下文和问题生成新文本[8]。

图 1 展示了抽取式和生成式模型可能产生的输出示例。

尽管有不同的方法,但如今抽取式和生成式问答模型通常都基于像 BERT 这样的 Transformer[8],[9]。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图 1. 抽取式与生成式方式生成的答案示例。

基于关于抽取式和生成式模型的信息,我们现在知道我们需要一个包含上下文、问题、答案以及(可选的)答案在上下文中的起始和结束索引的数据集。我已经探索了以下选项,以寻找合适的数据集。

  • 我使用了 Cambazoglu et al 的 2020 年论文[10],以获得有关问答数据集的清晰图像。他们的研究结果提供了一张包含最显著问答数据集的表格。不幸的是,这些大型数据集中没有荷兰语的数据集。

  • 另一个选择是 Huggingface,它托管了大量的数据集[11]。乍一看,有一些荷兰语的问答数据集。然而,进一步检查显示,这些数据集往往不完整,包含网站域名而不是上下文,或者是各种语言的混合。这些数据集完全无法使用,或者不够完整,无法用于我们的目标。

从这些观察结果来看,几乎没有公共数据集可以用来训练荷兰语问答模型。手动创建我们自己的数据集将花费太多时间,那么我们还有什么其他选项?首先,我们可以简单地使用一个英语模型,将荷兰语输入翻译成英语,然后将输出再翻译回荷兰语。然而,通过 Google 翻译进行的快速测试表明,这种方法的结果远非理想,几乎感觉有些消极攻击。也许在双重翻译步骤中丢失了太多信息和上下文?这就引出了第二个选项,即翻译整个数据集并在其上进行训练。在我的研究中,我遇到了一些提到这一点的实例。例如,Zoumana Keita 在 Towardsdatascience 上的一篇文章[16]使用翻译进行数据增强。第三章将深入探讨我如何执行数据集的翻译。

最后,我们需要选择用于翻译的方法的数据集。既然我们决定翻译整个数据集,那么原始数据集使用什么语言就不重要了。斯坦福问答数据集(SQuAD)[12] 似乎相当受欢迎,并被 Paperswithcode 用于问答基准测试[13]。它还包含大量(100,000+)的问答,并且经仔细检查后似乎没有任何意外数据。这就是我们将要使用的数据集。

机器学习模型

现在我们已经确定了如何获取数据集;我们需要决定哪种机器学习模型适合回答问题的目标。在前一章中,我们已经确定可以选择抽取式模型和生成式模型。在我的研究中,我使用了生成式模型,因为它基于较新的技术,并且给出了更有趣的结果。然而,以防有人希望采用抽取式模型,我也会对此进行介绍。这也与数据集的选择一致,因为它包含了答案的起始索引。

从头开始训练一个 Transformer 模型,至少可以说是低效的。P. Azunre 的《自然语言处理中的迁移学习》一书[14]深入探讨了为什么进行迁移学习,并展示了如何进行迁移学习的多个示例。大量大型 NLP 模型托管在 Huggingface[15]上,并可用于迁移学习。我选择了 t5-v1_1-base 模型,因为它经过多语言的多任务训练。第四章将介绍该模型的迁移学习。

翻译数据集

在本章中,我将展示如何通过提供代码片段并对其进行解释来翻译数据集。这些代码块连续生成的代码就是我编写的整个数据集翻译脚本。欢迎跟随或取用对你有用的特定部分。

导入

解决方案使用了几个模块。首先,我们需要以尽可能快的速度翻译文本。在我的研究中,我尝试使用来自 Huggingface 的各种翻译 AI 模型,但迄今为止,最快的翻译器是 Googletrans 模块,它使用了 Google Translate API。该解决方案还使用了 httpx 的 Timeout 来定义翻译的超时时间,使用 json 解析 SQuAD 数据集,使用 Pandas 处理数据框,以及使用 Time 来测量所有操作所需的时间。

from googletrans import Translator, constants
from httpx import Timeout

import json
import pandas as pd
import time

初始化

首先,我们应该定义几个在脚本中会用到的常量。为了方便访问,我在这里添加了源语言和翻译语言。

Googletrans 模块为我们提供了一个可以自定义超时时间的翻译器。我使用了相对较长的超时时间,因为在测试期间翻译经常超时。我将在本指南后面的部分提供更多有关这个问题的信息。

src_lang = "en"
dest_lang = "nl"

translator = Translator(timeout = Timeout(60))

阅读 SQuAD 数据集

以下代码从训练和验证 json 文件中提取上下文、问题和答案。这是通过将文件以 json 格式读取,并以一种提取三种列表的方式遍历数据来完成的。对于每个问题和答案,上下文被复制并添加到上下文列表中。这样我们可以通过使用索引轻松访问带有相关上下文和答案的问题。

def read_squad(path):
    with open(path, 'rb') as f:
        squad_dict = json.load(f)
    contexts, questions, answers = [], [], []
    for group in squad_dict['data']:
        for passage in group['paragraphs']:
            context = passage['context']

            for qa in passage['qas']:
                question = qa['question']
                if 'plausible_answers' in qa.keys():
                    access = 'plausible_answers'
                else:
                    access = 'answers'
                for answer in qa[access]:
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer['text'])
    return contexts, questions, answers

train_c, train_q, train_a = read_squad('squad-train-v2.0.json')
val_c, val_q, val_a= read_squad('squad-dev-v2.0.json')

时间

以下代码为我们提供了每个翻译所需时间的大致估计。

def time_translation(entries, name):
    start_time = time.time()
    translation = translator.translate(entries[0], dest=dest_lang, src= src_lang)
    duration = time.time() - start_time
    total_duration = len(entries)*duration
    print(f"translating {name} takes {total_duration/60/60} hours")

time_translation(train_c, "train contexts")
time_translation(train_q, "train questions")
time_translation(train_a, "train answers")
time_translation(val_c, "validation contexts")
time_translation(val_q, "validation questions")
time_translation(val_a, "validation answers")

翻译

记得我提到过翻译超时的问题吗?在我的研究过程中,我不断遇到翻译超时的问题,导致结果数据集被损坏。事实证明,Googletrans 模块并不是 100% 可靠的,因为它使用了 Google Translate API。我找到的解决办法是创建一个小的包装函数,该函数会不断尝试翻译,直到成功为止。经过这样处理后,我不再遇到超时问题。

def get_translation(text):
    success = False
    translation = ""
    while not success:
        translation = translator.translate(text, dest=dest_lang, src=src_lang).text
        success = True
    return translation

由于我们从数据集中提取上下文的方式,每个问题和答案对都有重复的上下文。直接翻译所有上下文会显得冗余且非常缓慢,因此以下翻译函数首先会将前一个上下文与当前上下文进行比较。如果它们匹配,则使用之前的翻译。

def translate_context(contexts, name):
    start_time = time.time()
    context_current = ""
    translated_contexts = []
    index = 0

    for context in contexts:
        index+=1
        if context != context_current:
            context_current = context
            print(f"[{index}/{len(contexts)}]")
            get_translation(context)
            context_translated = get_translation(context)
            translated_contexts.append(context_translated)
        else:
            translated_contexts.append(context_translated)

    duration = time.time() - start_time
    print(f"Translating {name} took {round(duration, 2)}s") 
    return translated_contexts

翻译问题和答案非常简单,因为我们只需循环遍历列表来翻译所有内容。

def translate_qa(input, name):
    start_time = time.time()
    input_translated = []
    index = 0
    for text in input:
        text_nl = get_translation(text)
        input_translated.append(text_nl)
        index+=1
        print(f"[{index}/{len(input)}]")
    duration = time.time() - start_time
    print(f"Translating {name} took {round(duration, 2)}s") 
    return input_translated

现在我们可以使用我们定义的函数来翻译数据集的所有部分。

train_c_translated = translate_context(train_c, "train contexts")
train_q_translated = translate_qa(train_q, "train questions")
train_a_translated = translate_qa(train_a, "train answers")

val_c_translated = translate_context(val_c, "val contexts")
val_q_translated = translate_qa(val_q, "val questions")
val_a_translated = translate_qa(val_a, "val answers")

导出

只剩下将翻译导出以供以后使用。我们可以通过将列表转换为数据框,然后使用 to_csv 函数来完成这一点。需要注意的是,Googletrans 模块输出的翻译包含 utf-8 编码中不包含的字符。这就是我们在这里使用 utf-16 编码的原因。将其转换为 utf-8 可能在某些时候更有用,因为这可能对 AI 模型更有帮助。然而,由于我们这里只是在处理数据集,所以我们可以决定将这一步骤留到后续数据预处理阶段。

def save_data(data, name, header):
    data_df = pd.DataFrame(data)
    data_df.to_csv(name + "_pdcsv.csv", encoding='utf-16', index_label = "Index", header = [header])

save_data(train_c_translated, "train_contexts", "contexts")
save_data(train_q_translated, "train_questions", "questions")
save_data(train_a_translated, "train_answers", "answers")
save_data(val_c_translated, "val_contexts", "contexts")
save_data(val_q_translated, "val_questions", "questions")
save_data(val_a_translated, "val_answers", "answers")

构建问答模型

发现如何训练一个问答模型的过程有点挑战。然而,通过借鉴 P. Suraj [17] 的 Notebook,我能够创建一个基于 Transformer 的模型,该模型可以用于问答训练。按照 Notebook 的指导,我使用了 Torch 来创建模型。

导入

从导入开始,使用了以下模块。我们还定义了一些变量,这些变量定义了模型的最大输入和输出长度。

import pandas as pd
import unicodedata

import torch
from torch.utils.data import DataLoader

from transformers import T5Tokenizer
from transformers import T5ForConditionalGeneration
from transformers import AdamW
from tqdm import tqdm

from sklearn.model_selection import train_test_split 
from datetime import datetime

max_text_length = 512
max_output_length = 256

加载数据

现在我们可以加载之前创建的数据集。由于我们使用 Pandas 导出了 csv,因此现在可以轻松加载并将其转换为数组。我还定义了一个函数,该函数将在后续将任何训练或输入数据转换为 utf-8,这是我们将用于训练模型的格式。

def load_data(path):
    df = pd.read_csv(path, encoding='utf-16')
    df = df.drop('Index', axis=1)
    data = df.values.tolist()
    data = [a[0] for a in data]
    return data

def to_utf8(text):
    try:
        text = unicode(text, 'utf-8')
    except NameError:
        pass
    text = unicodedata.normalize('NFD', text).encode('ascii', 'ignore').decode("utf-8")
    return str(text)

现在我们可以实际加载数据。在模型训练中,我只使用了训练数据,并将其拆分为测试数据,测试数据大小为 0.2。

contexts_csv = 'train_contexts_pdcsv.csv'
questions_csv = 'train_questions_pdcsv.csv'
answers_csv = 'train_answers_pdcsv.csv'

contexts = load_data(contexts_csv)
questions = load_data(questions_csv)
answers = load_data(answers_csv)

c_train, c_val, q_train, q_val, a_train, a_val = train_test_split(contexts,
                                                questions, answers,
                                                test_size=0.2,
                                                random_state=42)

准备数据

如我之前提到的,我们可以训练一个抽取式模型和一个抽象生成模型。在我的研究中,我开发了这两种模型。在这篇文章中,我只介绍抽象生成版本,但对于感兴趣的读者,我还会解释如何为抽取式模型预处理数据。这是为了创建上下文中答案的起始和结束索引。

抽象生成

数据集不需要过多预处理就可以训练抽象生成模型。我们只需将所有训练数据转换为 utf-8。可以取消注释最后三行,以减少训练集的大小,这将改善训练时间并有助于调试。

def clean_data(contexts, questions, answers):
    cleaned_contexts, cleaned_questions, cleaned_answers = [], [], []
    for i in range(len(answers)):
        cleaned_contexts.append(to_utf8(contexts[i]))
        cleaned_questions.append(to_utf8(questions[i]))
        cleaned_answers.append(to_utf8(answers[i]))
    return cleaned_contexts, cleaned_questions, cleaned_answers

cc_train, cq_train, ca_train = clean_data(c_train, q_train, a_train); 
cc_val, cq_val, ca_val = clean_data(c_val, q_val, a_val); 

print("Original data size: " + str(len(q_train)))
print("Filtered data size: " + str(len(cq_train)))

#cc_train = cc_train[0:1000]
#cq_train = cq_train[0:1000]
#ca_train = ca_train[0:1000]

抽取式

在许多情况下,抽取式模型需要上下文中答案的起始和结束索引。然而,由于我们使用 Transformer 翻译了数据集,可能会出现一些问题。例如,答案可能与上下文中的措辞不同,或者答案的位置或长度可能已经改变。为了解决这个问题,我们可以尝试在上下文中找到答案,如果找到答案,则将其添加到清理后的答案中。因此,我们也获得了关于起始索引的信息,结束索引简单地是起始索引加上答案的长度。

def clean_data(contexts, questions, answers):
    cleaned_contexts, cleaned_questions, cleaned_answers = [], [], []
    for i in range(len(answers)):
        index = contexts[i].find(answers[i])
        if(index != -1):
        #print(str(index) + " + " + str(index+len(answers[i])))
            cleaned_contexts.append(contexts[i])
            cleaned_questions.append(questions[i])
            cleaned_answers.append({
                'text':answers[i],
                'answer_start': index,
                'answer_end': index+len(answers[i])
                })
    return cleaned_contexts, cleaned_questions, cleaned_answers

cc_train, cq_train, ca_train = clean_data(c_train, q_train, a_train); 
cc_val, cq_val, ca_val = clean_data(c_val, q_val, a_val);

分词器

下一步是分词,因为我们使用的是 t5-v1_1-base,我们可以直接从 Huggingface 导入分词器。然后,我们将使用问题对上下文进行分词,以便分词器将它们与结束字符串标记一起添加。我们还指定了之前定义的 max_text_length。最后,分词后的答案被添加到编码中作为目标。

tokenizer = T5Tokenizer.from_pretrained('google/t5-v1_1-base')
train_encodings = tokenizer(cc_train, cq_train, max_length=max_text_length, truncation=True, padding=True)
val_encodings = tokenizer(cc_val, cq_val, max_length=max_text_length, truncation=True, padding=True)

def add_token_positions(encodings, answers):
    tokenized = tokenizer(answers, truncation=True, padding=True)
    encodings.update({'target_ids': tokenized['input_ids'], 'target_attention_mask': tokenized['attention_mask']})

add_token_positions(train_encodings, ca_train)
add_token_positions(val_encodings, ca_val)

数据加载器

我们将使用数据加载器来训练 PyTorch 模型。这里还可以指定批量大小。我训练时的服务器内存有限,所以我不得不使用批量大小为 2。如果可能的话,使用更大的批量大小会更好。

class SquadDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
        print(encodings.keys())

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

训练模型

我们使用的模型是基于 T5-v1_1-base 的 T5ForConditionalGeneration。如果用于训练的 PC 或服务器上安装了 CUDA,我们可以尝试利用它来显著提高训练速度。我们还告诉模型我们将对其进行训练。

我们使用的优化器是 AdamW,学习率为 1e-4。这个选择基于 T5 文档[18],文档中提到在我们的情况下这是一个合适的值:

通常,1e-4 和 3e-4 对于大多数问题(分类、摘要、翻译、问答、问题生成)效果很好。

最后,我们定义一个函数,在模型训练完成后将其保存以供以后使用。

model = T5ForConditionalGeneration.from_pretrained('google/t5-v1_1-base')
cuda = torch.cuda.is_available()
device = torch.device('cuda') if cuda else torch.device('cpu')
model.to(device)
model.train()

optimizer = AdamW(model.parameters(), lr=1e-4)

def save_model():
    now = datetime.now()
    date_time = now.strftime(" %m %d %Y %H %M %S")
    torch.save(model.state_dict(), "answer_gen_models/nlpModel"+date_time+".pt")

模型的实际训练将在三个时期内完成,我使用的 Notebook [17] 和 T5 文档都表明这是一个不错的训练周期数。在我配备 RTX 3090 的 PC 上,这大约需要每个周期 24 小时。我使用的服务器利用了 Nvidia Tesla T4,每个周期大约需要 6 小时。

Tqdm 模块用于对训练状态进行可视化反馈。它提供了关于已过时间和估计训练时间的数据。两个注释箭头之间的步骤对于我们的问答目标很重要,这里我们定义了给模型的输入。该代码块中的其他步骤对于 PyTorch 模型的训练相当直接。

for epoch in range(3):
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        optim.zero_grad()

        # >
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        target_ids = batch['target_ids'].to(device)
        target_attention_mask = batch['target_attention_mask'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask,
                        labels=target_ids,
                        decoder_attention_mask=target_attention_mask)
        # >
        loss = outputs[0]
        loss.backward()
        optimizer.step()

        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())
save_model()

结果

如果你跟随完成了,恭喜你!你已经创建了自己的荷兰数据集并训练了一个荷兰问答模型!如果你和我一样,可能迫不及待想尝试一下模型的结果。你可以使用以下代码来评估模型。有趣的是,你可能会发现模型不仅能够回答荷兰语问题!它也有能力回答不同(主要是日耳曼语)的语言的问题。这很可能是因为原始 T5-v1_1-base 模型已经在四种不同语言上进行了训练。

model = T5ForConditionalGeneration.from_pretrained('google/t5-v1_1-base')
model.load_state_dict(torch.load("answer_gen_models/some_model.pt"))

cuda = torch.cuda.is_available()
device = torch.device('cuda') if cuda else torch.device('cpu')
model.to(device)
model.eval()

def test(context, question):
    input = tokenizer([to_utf8(context)],
                      [to_utf8(question)],
                      max_length=max_text_length,
                      truncation=True,
                      padding=True)
    with torch.no_grad():
        input_ids = torch.tensor(input['input_ids']).to(device)
        attention_mask = torch.tensor(input['attention_mask']).to(device)
        out = model.generate(input_ids,
                             attention_mask=attention_mask,
                             max_length=max_output_length,
                             early_stopping=True)
        print([tokenizer.decode(ids,
        skip_special_tokens=True) for ids in out][0])

test("Dit is een voorbeeld", "Wat is dit?")

以下是一些示例背景和问题以及模型生成的答案:

背景 我们和应用人工智能硕士班的同学们去过科隆。

问题 班级去过哪里?

答案 科隆

背景 大棕色狐狸跳过了懒狗。

问题 狐狸跳过了什么?

答案 懒狗

背景 大棕色狐狸跳过了懒狗。

问题 狐狸做了什么?

答案 跳过懒狗

背景 两乘二是十。

问题 两乘二是多少?

答案

已经实现了什么,未实现什么?

总结一下,我们选择了一个用于问答的英文数据集,通过 Google Translate API 将其翻译成荷兰语,并训练了一个基于 T5-v1_1-base 的 PyTorch 编码器-解码器模型。我们究竟实现了什么,这在实际情况中是否能使用?

首先,重要的是要认识到我们没有对模型进行适当评估,因为这不是本文的范围。然而,为了能够正确解释我们的结果,并能够谈论其可用性,我建议查看如 Rouge [19] 等度量标准或进行人类评估。我采取的方法是人类评估。表 2 显示了五个人对各种上下文来源和问题生成答案的平均评分,评分范围从 1 到 5。平均分为 2.96\。这个数字本身并没有告诉我们很多信息,但我们可以从表中得出结论,我们创建的模型在某些情况下可以生成接近完美的答案。然而,它也经常生成评估小组认为完全无意义的答案。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

表 2. 各种文章、论文和学位论文的人类评估评分(1–5)。

还需要注意的是,通过翻译数据集,我们很可能引入了偏差。谷歌翻译背后的 AI 已经在一个数据集上进行训练,由于其基于自然语言,因此自然包含了偏差。通过翻译我们的数据,这种偏差将传递到任何使用该数据集训练的模型中。在像这样的数据集可以在实际情况下使用之前,应彻底评估,以指出其中的偏差以及这些偏差如何影响结果。

然而,这种解决方案对于那些在实验 AI、开发新型机器学习模型或仅仅学习 NLP 的人来说可能非常有趣。这是一种非常便捷的方式,可以为几乎任何 NLP 问题获取大规模的数据集。许多学生无法获得大数据集,因为这些数据集通常只对大型公司开放或费用过高。通过这样的方式,任何大型英语数据集都可以转换为特定语言的数据集。

未来计划

我个人对这个方法的应用前景非常感兴趣。我目前正在研究一个使用完全相同方法和数据集的问题生成模型。我希望调查这两种模型结合使用的效果,以便更多地了解潜在的偏差或错误。这与第五章讨论的评估需求是一致的。我已经通过请五个人对创建的模型的结果进行评分来创建了一个人类评估。然而,我打算进一步了解不同的度量标准,这些标准可以更好地告诉我模型的工作原理、生成某些结果的原因以及其中包含的偏差。

我还了解到,斯坦福问题与回答数据集的 2.0 版本包含一些无法回答的问题。虽然这与本文提供的解决方案没有直接关系,但我对将本文的解决方案应用于完整的 SQuAD 2.0 数据集后的结果差异感到好奇。

来源

[1] A. Vaswani et al.,“注意力机制是你所需要的一切,” 2017 年。

[2] D. Khurana, A. Koli, K. Khatter, 和 S. Singh,“自然语言处理:最新进展、当前趋势和挑战,” Multimedia Tools and Applications,2022 年 7 月,doi: 10.1007/s11042–022–13428–4。

[3] “什么是自然语言处理?| IBM,” www.ibm.comwww.ibm.com/topics/natural-language-processing(访问日期:2023 年 1 月 11 日)。

[4] E. Holloway,“是的,ChatGPT 是有意识的 — 因为实际上是人类在其中,” Mind Matters,2022 年 12 月 26 日。mindmatters.ai/2022/12/yes-chatgpt-is-sentient-because-its-really-humans-in-the-loop/(访问日期:2023 年 1 月 18 日)。

[5] “NLP — 概述,” cs.stanford.educs.stanford.edu/people/eroberts/courses/soco/projects/2004-05/nlp/overview_history.html(访问日期:2023 年 1 月 18 日)。

[6] S. Ruder,“自然语言处理最近历史的回顾,” Sebastian Ruder,2018 年 10 月 1 日。ruder.io/a-review-of-the-recent-history-of-nlp/(访问日期:2023 年 1 月 18 日)。

[7] S. Varanasi, S. Amin, 和 G. Neumann,“AutoEQA:用于提取式问答的自动编码问题,” 计算语言学协会年会论文集:EMNLP 2021,2021 年。

[8] “什么是问答? — Hugging Face,” huggingface.cohuggingface.co/tasks/question-answering(访问日期:2023 年 1 月 18 日)。

[9] R. E. López Condori 和 T. A. Salgueiro Pardo,“观点总结方法:比较和扩展提取式和抽象式方法,” 专家系统应用,第 78 卷,第 124–134 页,2017 年 7 月,doi: 10.1016/j.eswa.2017.02.006。

[10] B. B. Cambazoglu, M. Sanderson, F. Scholer, 和 B. Croft,“关于问答研究的公共数据集综述,” ACM SIGIR Forum,第 54 卷,第 2 期,第 1–23 页,2020 年 12 月,doi: 10.1145/3483382.3483389。

[11] “Hugging Face — 建设未来的人工智能社区,” huggingface.cohuggingface.co/datasets?language=language:nl&task_categories=task_categories:question-answering&sort=downloads(访问日期:2023 年 1 月 18 日)。

[12] “斯坦福问答数据集,” rajpurkar.github.iorajpurkar.github.io/SQuAD-explorer/(访问日期:2023 年 1 月 18 日)。

[13] “Papers with Code — 问答,” paperswithcode.compaperswithcode.com/task/question-answering(访问日期:2023 年 1 月 18 日)。

[14] P. Azunre,自然语言处理中的迁移学习。Simon and Schuster,2021 年。

[15] “Hugging Face — 一次提交解决 NLP 问题的使命。” huggingface.cohuggingface.co/models(访问日期:2023 年 1 月 18 日)。

[16] Z. Keita,“使用 MarianMT 进行 NLP 中的数据增强,” Medium,2022 年 11 月 5 日。 towardsdatascience.com/data-augmentation-in-nlp-using-back-translation-with-marianmt-a8939dfea50a(访问日期:2023 年 1 月 18 日)。

[17] P. Suraj,“Google Colaboratory,” colab.research.google.comcolab.research.google.com/github/patil-suraj/exploring-T5/blob/master/T5_on_TPU.ipynb(访问日期:2023 年 1 月 25 日)。

[18] “T5,” huggingface.cohuggingface.co/docs/transformers/model_doc/t5#transformers.T5Model(访问日期:2023 年 1 月 25 日)。

[19] “ROUGE — evaluate-metric 提供的 Hugging Face 空间,” huggingface.cohuggingface.co/spaces/evaluate-metric/rouge(访问日期:2023 年 1 月 25 日)。

除非另有说明,所有图片均为作者所摄。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值