使用 PostgreSQL、FastAPI 和 Docker 构建一个后端
原文:
towardsdatascience.com/build-a-back-end-with-postgresql-fastapi-and-docker-7ebfe59e4f06
开发基于地图的应用程序的逐步指南(第四部分)
·发表于 Towards Data Science ·阅读时间 28 分钟·2023 年 3 月 14 日
–
图片来自 Caspar Camille Rubin 于 Unsplash
地图是可视化和理解地理数据的强大工具,但需要特定的技能才能高效设计。
在这个逐步指南中,我们将深入探讨如何构建一个基于地图的应用程序,以展示客户周围加油站的价格。我们将涵盖产品的不同关键步骤,从最初的概念验证(POC)到最小可行产品(MVP)。
系列文章:
第一部分:概念验证——构建一个简约的演示
第二部分:如何使用 React 构建网页应用(静态布局)
第三部分:使用 React 为你的网页应用添加互动性
第四部分:使用 PostgreSQL、FastAPI 和 Docker 构建后端
关于这篇文章的一些背景信息
在系列文章的前几部分,我们使用 React 构建了加油站查找器的前端,并将后端视为一个仅提供相关数据的“黑箱”。
在这一部分,我们将详细介绍如何使用强大的工具,如 PostgreSQL 或 FastAPI,一步一步构建后端。
你可以在我的 Github 页面 找到该项目的完整代码。
为什么我们需要一个干净的后端?
在本系列的第一部分中,我们创建了一些实用函数,以便从公共提供商直接获取燃料站的数据。虽然这对于我们的概念验证足够了,但由于多种原因,我们现在需要一个更强大的系统:
-
性能与延迟:实时处理数据,包括解析 XML、格式化和过滤,计算成本高,对于预期频繁使用的应用程序来说,可能不切实际。
-
可靠性:确保我们的应用程序不会受到第三方数据源的意外更改或停机的影响。仅依赖外部门户的数据会使我们的应用程序面临风险,因为即使提供商简单地更改字段名称也可能导致我们这边的错误和停机,而我们需要修补这些变化。通过构建我们自己的数据库,我们可以对数据有更大的控制权,进行必要的更新和维护,而无需依赖外部方。
-
自定义:通过我们自己的数据库,我们可以定制数据以满足技术规格,添加其他外部数据源,为不同的用例构建自定义数据视图等……
为了满足这些需求,我们将构建自己的数据库和 API,处理数据获取、处理和前端交付。这将包括使用 Docker 运行 PostgreSQL 数据库,使用 Python 和 sqlalchemy 与数据库交互,并使用 PostGIS 扩展进行地理查询。我们还将探索如何使用 FastAPI 和 SQLmodel 构建一个简单的 API。
下图是应用程序不同组件的简单示意图:
我们应用程序不同组件的简单视图,作者插图
本文涵盖的内容
在本文中,我们专注于内部数据库和 API 的创建。具体来说,我们将:
-
使用 Docker 运行 PostgreSQL 数据库
-
使用 python 和 sqlalchemy 与数据库交互
-
使用 PostGIS 扩展进行地理查询
-
使用 FastAPI 和 SQLmodel 构建一个简单的 API
-
使用 docker-compose 容器化我们的项目并运行
使用 Docker 创建本地 PostgreSQL 实例
Docker 是一个开源容器化平台,允许你在一致且隔离的环境中运行应用程序。使用 Docker 设置 PostgreSQL 服务器有几个优点,包括能够以标准化的方式安装应用程序,而不必担心与系统上其他配置冲突。
在我们的案例中,我们将直接在容器内部设置 Postgre 服务器。
我在这里假设你已经在计算机上安装了 Docker,因为安装方法因系统而异。
获取容器镜像
Docker 镜像可以看作是构建一个专用于特定任务的容器所需的所有规格。它本身不做任何事情,但用于构建容器(一个专用虚拟环境),你的应用程序将运行在其中。我们可以使用 Dockerfile 创建我们自己的自定义镜像(稍后会讨论),或者我们可以从社区共享的各种开源镜像中下载现成的镜像。
在我们的情况下,我们需要一个可以帮助我们创建运行 PostgreSQL 的容器的镜像,我们可以使用 官方镜像 来实现这一目的。
我们从在 Docker 上下载 PostgreSQL 镜像开始。这是在 Shell 中完成的:
docker pull postgres
运行 Postgre 容器
一旦镜像在 Docker 中下载完成,我们可以使用以下命令基于它构建容器:
docker run -itd -e POSTGRES_USER=jkaub -e POSTGRES_PASSWORD=jkaub -p 5432:5432 -v ~/db:/var/lib/postgresql/data --name station-db postgres
让我们解密它。
-itd 是三个参数的组合:
-
-d 表示我们以分离模式运行容器。在这种模式下,容器将在后台运行,我们可以继续使用终端进行其他操作。
-
-i 指定我们的容器将以交互模式运行。这将允许我们进入容器并与之交互。
-
-t 表示容器内部将提供一个伪终端,以便与容器进行交互,这将带来更无缝和直观的交互体验。
-e 用于在容器内部生成环境变量。在这种情况下,环境变量 POSTGRES_USER 和 POSTGRES_PASSWORD 还用于用给定的密码生成我们的 PostgreSQL 实例的新用户。如果没有这个,我们仍然可以使用默认用户/密码(postgre/postgre)访问 PostgreSQL 实例。
-p 用于将本地机器的端口映射到 Docker 容器中。PostgreSQL 的默认端口是 5432。如果它已经在你的本地机器上被使用,你可以使用这个参数将容器中的 5432 映射到你机器上的另一个端口。
-v 是一个在我们这种情况下非常重要的参数:它允许我们将一个卷从我们的机器(在我们这种情况下是文件夹 ~/db
)映射到容器内部的卷,其中 SQL 数据默认存储在 /var/lib/postgresql/data
。通过进行这种映射,我们创建了一个持久卷,即使容器停止后也会保留。因此,我们的数据库将持续存在,即使我们停止使用容器,之后也可以使用。
— name 只是一个标志,用于给容器命名,这将有助于我们以后访问它。
我们可以通过使用下面的命令检查容器是否处于活动状态,该命令将显示我们机器上运行的容器列表:
docker ps
返回:
CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
cb0840806636 postgres "docker-entrypoint.s…" 2 minutes ago Up 2minutes 0.0.0.0:5432->5432/tcp station-db
与 PostgreSQL 的初步交互
我们的 PostgreSQL 实例现在正在容器中运行,我们可以与它进行交互。
创建数据库
作为起点,让我们创建一个包含项目不同表格的第一个数据库。
为此,我们需要进入容器。记住,这可能是因为我们在初始化容器时指定了 -it 参数。下面的命令行将完成这项工作:
docker exec -it station-db bash
命令提示符现在应该是:
root@cb0840806636:/#
这意味着我们以 root 用户身份登录到容器中。我们可以使用用户 (-U)/密码 (-d) 连接到 PostgreSQL,如下所示:
psql -U jkaub -d jkaub
一旦进入 PostgreSQL 实例,我们可以使用 SQL 查询与之交互,特别是创建一个新的数据库来托管我们未来的表。
CREATE DATABASE stations;
我们可以通过运行来验证数据库是否已经创建
\l
这将显示系统中的不同数据库。在一些在实例初始化时创建的默认数据库中,我们可以找到刚刚创建的那个:
jkaub=# \l
List of databases
Name | Owner | Encoding | Collate | Ctype | ICU Locale | Locale Provider | Access privileges
-----------+-------+----------+------------+------------+------------+-----------------+-------------------
stations | jkaub | UTF8 | en_US.utf8 | en_US.utf8 | | libc |
现在我们已经设置好了 PostgreSQL 实例,我们可以通过手动在psql中编写 SQL 查询来创建表和从 .csv 文件导入数据。虽然这种方法适用于一次性使用,但如果我们需要频繁更新表,它可能会变得繁琐且容易出错。
因此,为了促进自动化,我们将使用 Python 框架与数据库及其表进行交互。这将允许我们通过代码轻松创建、更新和查询数据库,使过程更加高效且减少错误。
用 sqlalchemy 打开一个会话
SQLalchemy 是一个开源 SQL 工具包和对象关系映射(ORM)工具,供 Python 开发者使用。它提供了一组高级函数来与数据库交互,而不是编写 SQL 查询。
这特别方便,因为它允许我们使用 Python 类(在这里也称为“模型”)定义表的结构,并使用面向对象的范式。我们的 Python ORM,sqlalchemy,在下一部分构建后端 API 时将特别有用。
让我们开始安装项目所需的库。在 sqlalchemy 的基础上,我们还将使用psycopg2,这是一个 PostgreSQL 适配器,可被 sqlalchemy 用作连接器。
pip install psycopg2 sqlalchemy
我们现在可以直接在 Python 中有效地创建一个会话来访问我们的数据库:
from sqlalchemy import create_engine
engine = create_engine('postgresql://jkaub:jkaub@localhost/stations')
# test the connection by executing a simple query
with engine.connect() as conn:
result = conn.execute('SELECT 1')
print(result.fetchone())
一步步解释这个脚本:
engine = create_engine('postgresql://jkaub:jkaub@localhost/stations')
create_engine 方法用于保持与数据库的连接。我们需要在这里指定一个数据库 URL,该 URL 包含连接到我们数据库所需的所有信息。
-
该 URL 的第一部分 postgresql**😕/** 是为了指定我们正在使用 PostgreSQL 连接,并且接下来将是该类型数据库连接的规格。如果你使用的是不同的数据库,如 SQLite,你将会有不同的基本 URL 和规格。
-
jkaub:jkaub 是连接到我们数据库的登录信息。
-
localhost 是运行数据库的服务器。服务器 IP 也可以用来连接远程服务器,或者,如我们稍后会看到的,在容器集群的情况下,我们在某些情况下也可以使用容器名称。
-
/stations 用于指定我们想要连接的数据库。在我们的例子中,我们连接到我们刚刚创建的“stations”。
# test the connection by executing a simple query
with engine.connect() as conn:
result = conn.execute('SELECT 1')
print(result.fetchone())
这部分代码目前仅用于测试连接是否成功。我们的数据库还没有表可以查询,所以我们只是运行一个虚拟查询。它应该返回 (1,),这意味着连接成功。
使用 FastAPI 构建 API
现在我们已经在 Docker 容器中设置了 PostgreSQL 数据库,并使用 SQLAlchemy 引擎访问了它,现在是时候开发与数据库交互的 API 了。
使用 API 有几个好处:
-
它提供了可重用性和平台/语言的独立性,允许多个服务使用相同的 API 端点。
-
它将数据库逻辑与应用逻辑分开,使得只要输入/输出被尊重,修改一个而不影响另一个变得更加容易。
-
它增加了一层安全性,因为你可以通过授权系统控制谁可以访问数据库。
-
最后,API 是可扩展的,可以在多个服务器上运行,使其在管理工作负载方面非常灵活。通过创建一组明确的 URL,我们将能够通过 API 从数据库中检索、修改、插入或删除数据。
关于 FastAPI
FastAPI 是一个现代的 Python 框架,在构建轻量级 API 方面特别高效,由 Sebastián Ramírez 开发。
结合sqlalchemy和pydantic这两个用于数据验证的 Python 库时,它特别高效(例如,它可以检查一个日期是否确实是日期,一个数字是否确实是数字,等等)。一起使用,它使我们能够通过框架有效地处理和查询表格。
更棒的是,Sebastián Ramírez 还设计了另一个库,sqlmodel,它结合了 pydantic 和 sqlalchemy,去除了一些冗余,并进一步简化了 API 的架构。
如果你还不熟悉 FastAPI,我建议你先看看这个教程,这个教程做得非常好。
在开始项目之前,我们需要安装多个库。
pip install uvicorn
pip install fastapi
pip install sqlmodel
pip install geoalchemy2
-
uvicorn 是运行 API 服务器的工具,并且非常适合与 FastAPI 一起使用。
-
fastapi 是 API 的核心引擎,我们将用它来创建不同的端点。
-
sqlmodel 将 sqlalchemy ORM 与 pydantic 的类型验证功能结合起来
-
geolochemy2 是 sqlalchemy 的一个扩展,用于执行地理查询。
初始化模型
让我们为我们的 API 项目创建一个新的仓库,从使用sqlmodel定义模型开始。 “模型”只是一个代表 SQL 中表格的 Python 类。
api/
|-- app/
|-- __init__.py
|-- models.py
我们的项目将包含 3 个表格,并遵循我们在 part I 中构建的初始设计。
-
一个包含与城市相关的信息(邮政编码、位置)的表格
-
一个包含关于燃气价格信息的表格
-
一个包含关于车站信息的表格
通过联接和地理过滤器组合这些表,将帮助我们构建前端请求的最终输出。
我们来看一下第一个表格,Cities 表:
from sqlmodel import Field, SQLModel
from datetime import datetime
from typing import Optional
class Cities(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
postal_code: str
name: str
lat: float
lon: float
类 “Cities” 继承自 SQLModel 类,结合了 sqlalchemy**’s** ORM 特性和 pydantic 的类型控制。
参数 table=True 表示如果数据库中尚不存在相应的表,则会自动创建该表,并匹配列名和列类型。
类的每个属性将定义每一列及其类型。特别是,“id”将作为我们的主键。使用 Optional 将指示 sqlalchemy 如果我们不提供 id,则自动生成 id。
我们还提供了另外两个表的模型:
from datetime import datetime
...
class GasPrices(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
station_id: str
oil_id: str
nom: str
valeur: float
maj: datetime = Field(default_factory=datetime.utcnow)
class Stations(SQLModel, table=True):
station_id: str = Field(primary_key=True)
latitude: float
longitude: float
cp: str
city: str
adress: str
注意在 Stations 表中,我们使用 station_id 作为主键,与 GasPrices 不同,该字段是必填的。如果在发送到表格时该字段为空,将会生成错误信息。
初始化引擎
在另一个专用文件中,为了保持项目结构化,我们将初始化引擎。我们称之为 services.py。
api/
|-- app/
|-- __init__.py
|-- models.py
|-- services.py
连接到数据库的方式与之前介绍的相同。
from sqlmodel import SQLModel, create_engine
import models
DATABASE_URL = 'postgresql://jkaub:jkaub@localhost/stations'
engine = create_engine(DATABASE_URL)
def create_db_and_tables():
SQLModel.metadata.create_all(engine)
注意函数 create_db_and_tables():该函数将在 API 初始化期间被调用,查看 models.py 中定义的模型,并直接在 SQL 数据库中创建它们(如果它们尚不存在的话)。
实操 API
我们现在可以开始开发主要组件,我们将在其中放置端点(即允许我们与数据库交互的 URL)。
api/
|-- app/
|-- __init__.py
|-- main.py
|-- models.py
|-- services.py
我们要做的第一件事是配置 FastAPI 在启动时的设置,并处理 API 授权。
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, HTTPException
from models import Cities, Stations, GasPrices
from services import engine, create_db_and_tables
#We create an instance of FastAPI
app = FastAPI()
#We define authorizations for middleware components
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
#We use a callback to trigger the creation of the table if they don't exist yet
#When the API is starting
@app.on_event("startup")
def on_startup():
create_db_and_tables()
一个重要的点是:默认情况下,我们的前端没有访问 API 调用的权限,如果你忘记配置中间件部分,将会在前端出现错误。你可以决定通过使用以下方式允许所有来源:
allow_origins=["*"],
但由于安全原因,我不推荐这样做,因为一旦上线,你基本上会将 API 向全世界开放。我们的前端目前在 localhost:3000 上本地运行,所以这是我们将允许的域名。
到那时,我们已经可以通过使用以下命令行来启动 API:
uvicorn main:app --reload
— reload 只是意味着每次在 API 正在运行时保存修改,它将重新加载以包含这些更改。
一旦启动,你可以看到一些日志显示在终端中,特别是:
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)
表示 API 服务器正在 localhost(等同于 IP 127.0.0.1)上的 8000 端口运行。
如前所述,启动 API 还会触发数据库中空表的创建(如果它们尚不存在)。因此,从你第一次启动 API 的那一刻起,使用 table=True 创建的模型将在数据库中有一个专用表。
我们可以通过在 psql 中从 PostgresSQL 容器内部轻松检查这一点。连接为主用户后,我们首先连接到数据库 station:
\c stations
我们现在可以检查我们的表是否已正确创建:
\dt
这将返回:
List of relations
Schema | Name | Type | Owner
--------+-----------------+-------+-------
public | cities | table | jkaub
public | gasprices | table | jkaub
public | stations | table | jkaub
我们还可以通过在 psql 中运行描述查询来验证列是否与我们的模型匹配,例如,对于 cities 表:
\d cities
Column | Type | Collation | Nullable | Default
-------------+-------------------+-----------+----------+-----------------------
-------------
id | integer | | not null | nextval('cities_id_seq
'::regclass)
postal_code | character varying | | not null |
name | character varying | | not null |
lat | double precision | | not null |
lon | double precision | | not null |
Indexes:
"cities_pkey" PRIMARY KEY, btree (id)
构建我们的第一个请求 — 使用 POST 请求向 Cities 添加行
Cities 表只会填充一次,并用于将邮政编码与城市的经纬度匹配,这在后续使用邮政编码查询这些位置时特别有帮助。
目前,数据存储在 .csv 文件中,我们想设计一个 POST 调用,用于更新表格,如果数据还不在数据库中,则一次添加一行。API 调用被放在 main.py 文件中。
from sqlmodel import Session
...
@app.post("/add-city/")
def add_city(city: Cities):
with Session(engine) as session:
session.add(city)
session.commit()
session.refresh(city)
return city
让我们逐行查看这段代码:
@app.post("/add-city/")
每个 API 端点都是通过装饰器定义的。我们在这里定义了两件事:请求的类型(get、post、put、delete…)和关联的 URL 端点 (/add-city/)。
在这个特定的案例中,我们将能够在 http://127.0.0.1:8000/add-city/ 执行 POST 请求
def add_city(city: Cities):
我们将不同的参数传递给函数以用于查询。在我们的案例中,post 请求将寻找 Cities 的实例,这将通过 JSON 传递在我们的请求中。此 JSON 将包含我们想要添加的新行的 Cities 表的每列的值。
with Session(engine) as session:
要连接到数据库,我们打开一个 Session。每个查询需要自己的会话。使用这种方法在会话内部出现意外情况时特别有用:在会话初始化和 commit() 之间所做的所有更改将在出现问题时回滚。
session.add(city)
session.commit()
session.refresh(city)
在这里,对象被添加到数据库中,然后进行提交。从提交的那一刻起,操作无法回滚。refresh 用于更新 DB 进行的任何修改的“city”对象。在我们的案例中,例如,会自动添加递增的“id”。
return city
我们通过以 JSON 格式发送 city 对象来结束请求。
我们现在可以在 python 中尝试请求(当然需要 API 正在运行):
import requests
url='http://127.0.0.1:8000/add-city/'
json = {
'postal_code': '01400',
'name':"L'Abergement-Clémenciat",
'lat':46.1517,
'lon':4.9306
}
req = requests.post(url, json=json)
请注意,我们在请求中发送的 JSON 键与我们要更新的表的列名称匹配。参数“id”是可选的,它将自动添加到操作中,我们不需要担心它。
这应该会在 API shell 中触发以下行:
INFO: 127.0.0.1:33960 - "POST /add-gas-price/ HTTP/1.1" 200 OK
这意味着请求成功。我们可以进一步验证该行是否已正确添加。返回到我们 docker 中的 psql,我们可以尝试以下查询:
SELECT * FROM cities LIMIT 1;
这将显示:
id | postal_code | name | lat | lon
----+-------------+-------------------------+---------+--------
1 | 01400 | L'Abergement-Clémenciat | 46.1517 | 4.9306
演示该行已被 API 有效添加到我们的数据库中。
此外,我们不希望邮政编码被重复添加。为此,我们将查询 Cities 表,根据我们尝试发送的邮政编码过滤表,并在找到具有该邮政编码的行时返回 HTML 错误,从而避免邮政编码重复。
from fastapi import FastAPI, HTTPException
...
@app.post("/add-city/")
async def add_city(city: Cities):
with Session(engine) as session:
#New code block
exist = session.query(Cities).filter(
Cities.postal_code == city.postal_code).first()
if exist:
raise HTTPException(
status_code=400, detail="Postal code already exists")
#New code block
session.add(city)
session.commit()
session.refresh(city)
return city
在这段新代码中,我们使用 sqlalchemy ORM 执行了第一次数据库查询:我们没有编写经典的 SQL(“SELECT FROM”),而是使用了一组函数来直接查询数据库。
exist = session.query(Cities).filter(
Cities.postal_code == city.postal_code).first()
-
.query
相当于 SELECT … FROM …,在我们的案例中,我们从 cities 表中选择所有内容。 -
.filter
相当于 WHERE 语句。特别地,我们要匹配与我们发送的对象的邮政编码相等的条目(由变量 city 表示)。 -
.first()
是自解释的,相当于 LIMIT 1。 -
如果没有找到行,则
exist
将为 None,并且不会引发异常,因此我们将把对象添加到数据库中。如果某行与邮政编码匹配,API 请求将返回状态码 400 的错误,且 POST 请求将被中断,元素不会被添加。
如果我们现在尝试发送完全相同的请求,将看到 API 返回错误消息:
INFO: 127.0.0.1:49076 - "POST /add-city/ HTTP/1.1" 400 Bad Request
而且该行没有被添加到表中。
从那时起,我们可以简单地遍历 .csv
并逐一添加所有城市,以填充 cities 表。
使用 POST 请求向 Gasprices 和 Stations 表中添加行
我们将非常快速地跳过这些 API 调用的构建,因为它们与之前的非常相似。
@app.post("/add-station/")
async def add_station(station: Stations):
with Session(engine) as session:
exist = session.query(Stations).filter(
Stations.station_id == station.station_id).first()
if exist:
raise HTTPException(
status_code=400, detail="Station already exists")
session.add(station)
session.commit()
session.refresh(station)
return station
@app.post("/add-gas-price/")
async def add_station(gasPrice: GasPrices):
with Session(engine) as session:
exist = session.query(GasPrices). \
filter(GasPrices.oil_id == gasPrice.oil_id). \
filter(GasPrices.maj == gasPrice.maj). \
first()
if exist:
raise HTTPException(
status_code=400, detail="Entry already exists")
session.add(gasPrice)
session.commit()
session.refresh(gasPrice)
return gasPrice
唯一需要注意的有趣之处在于,我们使用了双重过滤查询,以确保仅在 oil_id
有新更新时才添加一行。这样,我们确保了未来的更新不会在价格没有变化的情况下创建重复项,从而节省了数据库空间。
为了检索和处理油价,我们只是回收了来自 Part I 的解析代码,获取相应的数据集并循环遍历,为每个条目进行 POST 调用。
以下脚本在 API 范围之外执行,以将数据上传到数据库:
import request
from data_parsing import get_data
BASE_API_URL = 'http://127.0.0.1:8000'
#get_data is the function designed in part I to pull the xml from the opendata
#source and convert them in Dataframes
stations, gas = get_data()
#Pushing stations data
to_push = stations[['latitude','longitude','cp','adress','city','station_id']].to_dict('records')
url=f'{BASE_API_URL}/add-station/'
for elmt in to_push:
req = requests.post(url, json=elmt)
#Pushing gasprices data
to_push = gas.to_dict('records')
url=f'{BASE_API_URL}/add-gas-price/'
for elmt in to_push:
req = requests.post(url, json=elmt)
注意:为了简化起见,我在这里选择逐行推送数据。我们也可以设计端点以批量推送数据,并发送 JSON 列表。
构建前端使用的 GET 查询
到这一点为止,我们的数据库已完全填充,以上脚本可以用于用更近期的数据更新数据库,我们可以开始构建前端用来查询特定城市周围加油站特定燃料价格的 GET 请求。
我决定为这个特定查询专门设置一个完整的章节,因为它的复杂性(我们将使用到目前为止定义的所有表,进行连接和地理过滤),同时因为我们需要在此时对数据库进行一些更改,以集成空间功能,安装附加组件和修改一些模型。虽然这可以从一开始就直接完成,但在实际项目中进行修改是常见的,我认为展示如何平滑地完成这一过程是很有趣的。
安装 PostGIS
PostGIS 是 PostgreSQL 的一个扩展,允许我们构建地理查询,这意味着需要一个空间组件。例如,在我们的情况下,我们能够选择距离某个兴趣点 30 公里半径内的所有车站数据行。
现在我们不想直接在运行的容器中安装 PostGIS,因为每次我们需要弹出一个新的容器时,这个安装会“丢失”,新容器基于仅安装 PostgreSQL 的镜像。
相反,我们将简单地更改用于构建容器的镜像,并将其替换为包含 PostgreSQL 和 PostGIS 的镜像。我们将提供相同的持久存储位置,以便新容器也可以访问它。
要构建包含 PostGIS 扩展的容器,我们首先从 docker 中拉取最新的 PostGIS 镜像,然后杀死并删除当前的 PostgreSQL 容器,并用新的镜像构建一个新的容器。
docker pull postgis
docker kill stations
docker rm stations
docker run -itd -e POSTGRES_USER=jkaub -e POSTGRES_PASSWORD=jkaub -p 5432:5432 -v ~/db:/var/lib/postgresql/data --name station-db postgis/postgis:latest
我们可以像之前一样访问容器,但现在我们使用的是包括 PostGIS 的 PostgreSQL 版本。
我们现在需要将扩展添加到现有数据库中。我们首先重新连接到数据库:
docker exec -it station-db bash
psql -U jkaub -d jkaub
\c stations
然后我们在其中包含 PostGIS 扩展:
CREATE EXTENSION postgis;
修改我们的 Stations 模型
现在我们的数据库中已经运行了 PostGIS,我们需要修改我们的 Stations 表以便能够执行地理查询。更准确地说,我们需要添加一个“geometry”字段,该字段被理解并转换为地球上的实际位置。
建立地图或标示地球上的位置有多种方法,每种方法都有其自己的投影和参考坐标系统。为了确保一个系统能够与另一个系统对话,我们需要确保它们使用相同的语言,这可能包括单位的转换(就像我们可以将米转换为英尺,或将千克转换为磅)。
对于坐标,我们使用称为“地理参数数据集”(EPSG)的东西。纬度和经度(EPSG 4326)以角度表示,无法直接转换为距离(欧几里得几何,包括距离计算,不能直接应用于球面,因为本质上这不是一个欧几里得表面)。相反,需要将其投影到平面表示中,这在 PostGIS 中处理得很好,只要我们注意并应用适当的转换。
作为起点,我们需要在 Stations 数据库中添加一个可以被解释为“几何”坐标的新字段。在我们的数据库中:
ALTER TABLE stations ADD COLUMN geom geometry(Point, 4326);
这一行将修改我们的 stations 表,添加一个新的字段 “geom”,它是类型为 “point” 的 PostGIS 几何图形,使用 EPSG 4326(经纬度系统的 EPSG)表示。该字段目前对所有行为空,但我们仍然可以在 SQL 中轻松填充它,以更新当前表(此时表并不为空)。
UPDATE stations SET geom = ST_SetSRID(ST_MakePoint(longitude, latitude), 4326);
上面的 SQL 查询将为 Stations 表的每一行设置 geom 列,使用由经度/纬度构建的点。注意我们在这里使用了两个 PostGIS 函数,ST_MakePoint 和 ST_SetSRID,来帮助我们在 SQL 中定义几何图形。
我们可以检查这个新的几何图形在数据库中的存储方式
SELECT * FROM stations LIMIT 1;
station_id | latitude | longitude | cp | city | adress | geom
------------+----------+-----------+-------+-------+-----------------------+----------------------------------------------------
26110004 | 44.36 | 5.127 | 26110 | NYONS | 31 Avenue de Venterol | 0101000020E6100000355EBA490C821440AE47E17A142E4640
你可以在这里看到,几何图形以字符串编码,这种格式是 Well-Known Binary (WKB),它在存储几何图形时非常高效。我不会进一步展开这个内容,但如果你在数据集中看到这个格式不要感到惊讶,如果需要,你可能需要将其解码为更可读的格式。
现在,我们还需要更新 model.py 文件中的 Stations 类以包含这个新字段,为此我们使用 geoalchemy 的 “Geometry” 类型。
from typing import Any
from geoalchemy2.types import Geometry
class Stations(SQLModel, table=True):
station_id: str = Field(primary_key=True)
latitude: float
longitude: float
cp: str
city: str
adress: str
geom: Optional[Any] = Field(sa_column=Column(Geometry('GEOMETRY')))
最后的修改是:我们希望在 POST 调用(在 main.py 中)时,使用纬度和经度参数自动计算几何图形:
from geoalchemy2.elements import WKTElement
@app.post("/add-station/")
async def add_station(station: Stations):
with Session(engine) as session:
exist = session.query(Stations).filter(
Stations.station_id == station.station_id).first()
if exist:
raise HTTPException(
status_code=400, detail="Station already exists")
#New code block
point = f"POINT({station.longitude} {station.latitude})"
station.geom = WKTElement(point, srid=4326)
#New code block
session.add(station)
session.commit()
session.refresh(station)
#This is only done to return a clean dictionnar with a proper json format
to_return = {}
to_return["station_id"] = station.station_id
to_return["latitude"] = station.latitude
to_return["longitude"] = station.longitude
to_return["cp"] = station.cp
to_return['city'] = station.city
to_return["adress"] = station.adress
return to_return
在这里,我们通过字符串创建一个点,使用另一种名为 WKTElement 的格式,这是一种使用人类可读的字符串编码几何图形的方式。我们的字符串随后通过 geolalchemy 函数 WKTElement 转换为几何图形,该函数隐式地将其转换为 WKB 格式以便编码到数据库中。
注意,“geom”不是 JSON 可序列化的,因此我们需要在通过 API 发送站点对象之前修改或删除它。
构建最终的 GET 查询
GET 查询的目标是从通过邮政编码识别的城市中检索 30 公里半径内的所有站点,并显示查询中提到的某种类型的所有站点的燃料最新价格,并附上一些美化的信息,如规范化的地址或 Google 地图链接。
{
"lat": 49.1414,
"lon": 2.5087,
"city": "Orry-la-Ville",
"station_infos": [
{
"address": "Zi Route de Crouy 60530 Neuilly-en-Thelle",
"price_per_L": 1.58,
"price_tank": 95,
"delta_average": 25.1,
"better_average": 1,
"google_map_link": "https://www.google.com/maps/search/?api=1&query=Zi+Route+de+Crouy+60530+Neuilly-en-Thelle",
"distance": 19.140224654602328,
"latitude": 49.229,
"longitude": 2.282
}, ...
]
}
我们将分两步进行:
-
首先构建一个高效的 SQL 查询来执行连接和过滤操作
-
在通过 API 发送结果之前,使用 Python 函数修改查询的输出。
与其他参数通过请求体中的 JSON 传递的查询不同,我们在这里将使用另一种约定,即将查询参数直接传递在 URL 中,见下面的示例:
http://localhost:8000/stations/?oil_type=SP98&postal_code=60560
在 FastAPI 中,这可以通过简单地向用于构建端点的函数中添加输入来自然完成:
@app.get("/stations/")
async def get_prices(oil_type: str, postal_code: str):
with Session(engine) as session:
...
现在我们要首先检索的是与邮政编码相关联的城市的纬度和经度。如果没有与邮政编码关联的城市,API 应该返回一个错误代码,说明未找到邮政编码。
city = session.query(Cities).filter(
Cities.postal_code == postal_code
).first()
if not city:
raise HTTPException(
status_code=404, detail="Postal Code not found")
接下来,我们将构建一系列子查询。每个子查询在最终查询完全执行之前不会被评估。这将帮助我们保持代码的可读性,并优化查询,因为 sqlalchemy ORM 会根据这些子查询动态优化查询。
我们要执行的第一个子查询是从 Stations 表中选择所有在已查询城市 30 公里半径范围内的车站。
stations = session.query(
Stations.station_id, Stations.adress, Stations.cp, Stations.city,
Stations.latitude, Stations.longitude,
).filter(
ST_Distance(
Stations.geom.ST_GeogFromWKB(),
WKTElement(f"POINT({city.lon} {city.lat})",
srid=4326).ST_GeogFromWKB()
) < 30000).subquery()
这里有许多有趣的地方需要注意。
-
我们只在 session.query( … ) 中选择了少量列,并且不保留 geom 列,该列仅用于过滤。在标准 SQL 中,这可以通过 “SELECT station_id, adress, cp, city, latitude, longitude FROM stations” 来完成。
-
我们使用ST_Distance,这是 geoalchemy 的内置函数,用于计算两个地理位置之间的距离(另一种 geoalchemy 类型)。
-
ST_Distance 也可以与几何对象一起工作,但输出将变成角度距离(请记住,纬度/经度是以角度表示的),这不是我们想要的。
-
要将几何对象转换为地理对象,我们只需使用另一个内置函数 ST_GeoFromWKB,它会自动将我们的几何体投影到其参考系统中,以在地球上形成一个点。
接下来,我们根据所需的 oil_type(如 SP95、Gazole 等)过滤 Gasprices 表。
price_wanted_gas = session.query(GasPrices).filter(
GasPrices.nom == oil_type
).subquery()
我们还需要根据数据集中最新的价格来过滤 Gasprices 表。这不是一项容易的任务,因为所有价格的更新不是同时完成的。我们将分两步构建子查询。
首先,我们通过从 price_wanted_gas 子表中提取 station_id 和最后更新时间来执行聚合。
last_price = session.query(
price_wanted_gas.c.station_id,
func.max(price_wanted_gas.c.maj).label("max_maj")
).group_by(price_wanted_gas.c.station_id) \
.subquery()
然后使用这些信息帮助我们通过连接过滤 price_wanted_gas,其中仅保留最新更新价格的行。“and_” 方法允许我们在连接操作中使用多个条件。
last_price_full = session.query(price_wanted_gas).join(
last_price,
and_(
price_wanted_gas.c.station_id == last_price.c.station_id,
price_wanted_gas.c.maj == last_price.c.max_maj
)
).subquery()
最后,我们在 last_price_full 子表(包含给定燃料的所有最新价格)和 stations 子表(包括所有在 30 公里半径内的车站)之间进行最终连接,并检索所有结果。
stations_with_price = session.query(stations, last_price_full).join(
last_price_full,
stations.c.station_id == last_price_full.c.station_id
).all()
到达这一点时,我们检索了经过过滤的相关车站列表,并将其与 GasPrices 表中的相关信息(即:价格)合并,我们只需要对输出结果进行后处理,以符合前端的要求。由于此时表格已经被清理和过滤,因此最终的后处理步骤可以在原生 Python 中完成,而不会对性能产生太大影响。
我将稍微详细说明这个最终的后处理步骤,因为它不在文章的核心部分,但请随时查看 GitHub 仓库以获取更多信息。
prices = [float(e["valeur"]) for e in stations_with_price]
avg_price = float(np.median(prices))
output = {
"lat": city.lat,
"lon": city.lon,
"city": pretify_address(city.name),
"station_infos": sorted([extend_dict(x, avg_price, city.lat, city.lon) for x in stations_with_price], key=lambda x: -(x['delta_average']))
}
return output
我们现在可以测试并验证查询是否返回了相关输出。我们可以使用 Python 请求进行检查,但 FastAPI 还提供了所有端点的内置文档,您可以在 localhost:8000/docs
测试您的 API。
FastAPI 内置文档的截图,作者插图
容器化应用程序
现在我们有了一个运行中的 API,我们将通过将应用程序打包到容器中来完成本文。
这是我们项目的组织方式:
stations-project/
|-- db/
|-- api/
|-- app/
|-- requirements.txt
|-- Dockerfile
|-- update_scripts/
|-- front/
|-- docker-compose.yml
我们将使用 api/ 中的 Dockerfile 来容器化 API,并使用 docker-compose 同时管理 API 和数据库。
文件夹 db/ 是 PostgreSQL 容器用来持久化数据库的卷。
打包我们的 API
为了打包我们的 API,我们将简单地构建一个 docker 镜像,该镜像将复制运行 API 所需的环境和依赖项。这个 docker 镜像将包含运行 API 所需的所有内容,包括代码、运行时、系统工具、库和配置。
为此,我们需要编写一个 Dockerfile,其中包含设置 FastAPI 环境的一系列指令。编写 Dockerfile 相对容易,只要理解了原理:它就像是从头开始配置一台新机器。在我们的案例中:
-
我们需要安装相关版本的 Python
-
设置工作目录
-
将相关文件复制到我们的工作目录中(包括 requirements.txt 文件,该文件是使用 pip install 安装项目所需所有库的强制要求)
-
使用 pip install 安装库
-
暴露 FastAPI 端口
-
运行初始化 API 的命令(uvicorn main:app — reload)
用 Docker 语言翻译过来,这变成了:
FROM python:3.9
WORKDIR /code
COPY ./requirements.txt /code/requirements.txt
COPY ./app /code/app
RUN pip install --no-cache-dir -r requirements.txt
EXPOSE 80
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
我们还需要处理 requirements.txt 文件,其中明确列出所有使用的库及其版本。
fastapi==0.94.0
GeoAlchemy2==0.13.1
numpy==1.24.2
SQLAlchemy==1.4.41
sqlmodel==0.0.8
uvicorn==0.20.0
psycopg2==2.9.5
在进行这些更新之后,我们现在可以构建容器镜像(在包含 Dockerfile 的文件夹内):
docker build -t fast-api-stations .
使用 docker-compose
docker-compose 是一个用于定义和运行多容器 Docker 应用程序的工具。在我们的案例中,我们希望同时运行 SQL 容器和 FastAPI 容器。我将假设您已经在计算机上安装了 docker-compose。如果没有,请按照这些说明进行操作。
为了使用 docker-compose,我们只需在项目的根目录中配置一个 docker-compose.yml
文件,该文件定义了构成应用程序的服务及其各自的配置。
docker-compose.yml
文件使用 YAML 语法定义了一组服务,每个服务代表一个将作为全球应用程序的一部分运行的容器。每个服务可以指定其镜像、构建上下文、环境变量、持久化卷、端口等……
这就是我们的 docker-compose.yml 文件的样子:
version: "3"
services:
fastapi:
image: fast-api-stations
ports:
- "8000:80"
stationdb:
image: postgis/postgis
environment:
POSTGRES_USER: jkaub
POSTGRES_PASSWORD: jkaub
POSTGRES_DB: stations
volumes:
- ./db:/var/lib/postgresql/data
如您所见,我们定义了两个服务:
-
一个用于 API 的服务,现在名为 FastAPI,构建在我们在前一个小节中创建的 Docker 镜像 fast-api-station 上。对于这个服务,我们将容器的 80 端口暴露给本地的 8000 端口。
-
一个用于 DB,运行在 PostGIS 镜像上。我们指定了与之前相同的环境变量和相同的卷以持久化数据库。
最后一个小修改
我们曾经使用本地 IP 连接到 SQL 引擎。由于我们现在在两个不同的环境中运行 API 和 PostgreSQL,我们需要更改连接数据库的方式。
docker-compose 自行管理不同容器之间的网络,并使我们能够轻松地从一个服务连接到另一个服务。为了从 API 服务连接到 SQL 服务,我们可以在引擎创建时指定要连接的服务名称:
DATABASE_URL = 'postgresql://jkaub:jkaub@stationdb/stations'
运行后端
现在我们已经配置好了所有内容,我们可以通过以下方式运行我们的后端应用程序:
docker-compose up
API 将通过 8000 端口提供服务。
http://localhost:8000/docs
结论
在这篇文章中,我们一直在处理我们 GasFinder 应用程序的后端。
我们决定将应用程序的所有相关数据存储在我们自己的存储解决方案中,以避免所有可能与依赖第三方连接相关的问题。
我们利用了 Docker 和 PostgreSQL+PostGIS 构建了一个数据库,使我们能够执行高效的地理查询,并使用 Python 框架 FastAPI + SQLModel 构建了一个高效的 API,可以用来与数据库交互,并向前端提供数据,这些前端是在之前的文章中开发的。
目前,我们有一个基于“生产标准”工具(React、PostgreSQL、FastAPI 等)的原型,可以在本地 100%运行。在本系列的最后部分,我们将看看如何使应用程序上线并自动更新我们的 SQL 表,以始终提供最新的信息。
使用这个技巧构建更好的条形图
原文:
towardsdatascience.com/build-a-better-bar-chart-with-this-trick-c66979cb17e1
(这实际上是一个 seaborn 散点图!)
·发布在 Towards Data Science ·7 min read·2023 年 8 月 26 日
–
“国会年龄”散点图的一部分(所有图片由作者提供)
每当我需要寻找有效的可视化灵感时,我都会浏览 经济学人、视觉资本家 或 华盛顿邮报。在其中一次探索中,我发现了一个有趣的信息图表——类似于上面展示的图表——它绘制了每个美国国会议员的年龄与他们的代际群体之间的关系。
我的第一印象是这是一个 水平条形图,但仔细观察发现每个条形图由多个 标记 组成,使其成为一个 散点图。每个标记代表一个国会成员。
在这个 快速成功数据科学 项目中,我们将使用 Python、pandas 和 seaborn 重建这个吸引人的图表。在这个过程中,我们将揭示一些你可能不知道存在的标记类型。
数据集
由于美国有 候选资格年龄 法律,国会成员的生日属于公开记录。你可以在多个地方找到它们,包括 美国国会传记名录 和 维基百科。
为了方便,我已经编制了一个包含当前国会议员姓名、生日、政府分支和政党的 CSV 文件,并将其存储在这个 Gist 中。
代码
以下代码是在 Jupyter Lab 中编写的,并且由单元格描述 描述。
导入库
from collections import defaultdict # For counting members by age.
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import patches # For drawing boxes on the plot.
import pandas as pd
import seaborn as sns
为代际数据分配常量
我们将对图表进行标注,以突出显示诸如婴儿潮一代和X 世代等代际群体。以下代码计算每个群体的当前年龄范围,并包含代际名称和突出显示颜色的列表。因为我们希望将这些列表视为常量,所以我们将名称大写,并使用下划线作为前缀。
# Prepare generational data for plotting as boxes on chart:
CURRENT_YEAR = 2023
_GEN_NAMES = ['Silent', 'Boomers', 'Gen X', 'Millennials', 'Gen Z']
_GEN_START_YR = [1928, 1946, 1965, 1981, 1997]
_GEN_END_YR = [1945, 1964, 1980, 1996, 2012]
_GEN_START_AGE = [CURRENT_YEAR - x for x in _GEN_END_YR]
_GEN_END_AGE = [CURRENT_YEAR - x for x in _GEN_START_YR]
_GEN_COLORS = ['lightgray', 'white', 'lightgray', 'white', 'lightgray']
将生日转换为年龄
为了计算每位成员的年龄,我们首先必须将参考日期(8/25/2023)和 DataFrame 的“Birthday”列转换为datetime格式,使用 pandas 的to_datetime()
方法。
现在我们有了兼容的、"日期感知"格式,我们可以通过减去这两个值,提取天数,然后将天数除以 365.25 来生成一个"年龄"列。
# Load the data:
df = pd.read_csv('https://bit.ly/3EdQrai')
# Assign the current date:
current_date = pd.to_datetime('8/25/2023')
# Convert "Birthday" column to datetime:
df['Birthday'] = pd.to_datetime(df['Birthday'])
# Make a new "Age" column in years:
df['Age'] = ((current_date - df['Birthday']).dt.days) / 365.25
df['Age'] = df['Age'].astype(int)
df.head(3)
初始 DataFrame 的头部
计算成员的年龄
我们最终希望按党派和政府分支对成员进行分组。这意味着我们需要生成四个单独的图表。(我们将与民主党一起包含 3 位独立人士,他们与民主党一起开会)。
与简单的条形图不同,我们需要知道的不仅仅是,比如说,57 岁共和党参议员的总数。由于我们想为特定年龄类别的每个成员绘制单独的标记,我们需要一个累计总数。这样,我们可以使用(计数,年龄)值作为散点图中的(x,y)坐标。因此,第一位 57 岁共和党参议员将在计数列中标记为“1”,第二位参议员标记为“2”,以此类推。
为了管理这一点,我们将首先设置四个 DataFrame 列来保存计数,然后制作四个相应的字典来记录初始计数。我们将使用collections
模块的[defaultdict()](https://docs.python.org/3/library/collections.html#defaultdict-objects)
容器,而不是标准字典,因为它会为不存在的键提供默认值,而不是引发令人烦恼的KeyError
。
接下来,我们将遍历 DataFrame,按“Branch”和“Party”列进行过滤。每次我们增加字典时,我们将更新“Age”列。这使我们能够保持匹配年龄的累积计数。
请注意,我们使用负数值来表示民主党计数,因为我们希望它们绘制在中央轴的左侧,而共和党的值绘制在右侧。
# Initialize count columns:
df['R count house'] = 0
df['D count house'] = 0
df['R count senate'] = 0
df['D count senate'] = 0
# Create dictionaries with default values of 0:
r_count_h_dict = defaultdict(int)
d_count_h_dict = defaultdict(int)
r_count_s_dict = defaultdict(int)
d_count_s_dict = defaultdict(int)
# Iterate through the DataFrame and update counts:
for index, row in df.iterrows():
age = row['Age']
if row['Branch'] == 'House':
if row['Party'] == 'R':
r_count_h_dict[age] += 1
df.at[index, 'R count house'] = r_count_h_dict[age]
elif row['Party'] == 'D':
d_count_h_dict[age] -= 1
df.at[index, 'D count house'] = d_count_h_dict[age]
elif row['Branch'] == 'Senate':
if row['Party'] == 'R':
r_count_s_dict[age] += 1
df.at[index, 'R count senate'] = r_count_s_dict[age]
elif row['Party'] == 'D':
d_count_s_dict[age] -= 1
df.at[index, 'D count senate'] = d_count_s_dict[age]
elif row['Party'] == 'I':
d_count_s_dict[age] -= 1
df.at[index, 'D count senate'] = d_count_s_dict[age]
df.head(3)
遮蔽零计数
我们不想绘制零值,因此我们将使用掩码将这些值转换为 DataFrame 中的NaN
(非数字)值。
# Filter out zero values:
mask = df != 0
# Apply the mask to the DataFrame:
df = df[mask]
df.head(3)
定义绘图函数
如前所述,我们将制作四个图表。为了避免重复代码,我们将把绘图指令封装到一个可重用的函数中。
这个函数将接受一个 DataFrame、一个 matplotlib 轴对象的名称、作为 x 坐标的列、一个颜色和一个标题作为参数。我们会关闭 seaborn 的大部分默认设置,比如轴刻度和标签,以便我们的图形尽可能干净和简洁。
这个图的一个重要组成部分是用于每个国会议员的矩形标记(marker=$\u25AC$
)。这个标记不是标准 matplotlib 集合的一部分,而是STIX 字体符号的一部分。你可以在这里找到这些替代标记的列表。
def make_plot(data, ax, x, color, title):
"""Make a custom seaborn scatterplot with annotations."""
sns.scatterplot(data=data,
x=x,
y='Age',
marker='$\u25AC$',
color=color,
edgecolor=color,
ax=ax,
legend=False)
# Set the border positions and visibility:
ax.spines.left.set_position('zero')
ax.spines.right.set_color('none')
ax.spines.top.set_color('none')
ax.spines.bottom.set_color('none')
# Set x and y limits, ticks, labels, and title:
ax.set_xlim(-15, 15)
ax.set_ylim(25, 100)
ax.tick_params(bottom=False)
ax.set(xticklabels=[])
ax.set(yticklabels=[])
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_title(title)
# Manually annotate the y-axis along the right border:
ax.text(x=12.5, y=96, s='Age')
ax.set_yticks(np.arange(30, 101, 10))
ylabels = [30, 40, 50, 60, 70, 80, 90]
for label in ylabels:
ax.text(x=13, y=label, s=str(label))
# Add shading and annotation for each generation:
for _, (name, start_age, end_age, gcolor) in enumerate(zip(_GEN_NAMES,
_GEN_START_AGE,
_GEN_END_AGE,
_GEN_COLORS)):
rect = patches.Rectangle((-15, start_age),
width=30,
height=end_age - start_age,
facecolor=gcolor,
alpha=0.3)
rect.set_zorder(0) # Move shading below other elements.
ax.add_patch(rect)
ax.text(x=-15, y=end_age - 2, s=name)
plt.tight_layout()
绘制图形
以下代码设置了图形并调用了make_plot()
函数四次。最后添加了超级标题和自定义图例。
# Make the figure and call the plotting function:
fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(8, 5))
make_plot(df, ax0, 'D count house', 'blue', 'House' )
make_plot(df, ax0, 'R count house', 'firebrick', 'House')
make_plot(df, ax1, 'D count senate', 'blue', 'Senate')
make_plot(df, ax1, 'R count senate', 'firebrick', 'Senate')
# Add figure title and custom legend:
fig.suptitle('Age of US Congress 2023')
ax0.text(x=-15, y=17, s='$\u25AC$ Democrat & Independent', color='blue')
ax0.text(x=1.7, y=17, s='$\u25AC$ Republican', color='firebrick');
# Optional line to save figure:
# plt.savefig('age_of_congress.png', bbox_inches='tight', dpi=600)
最终的图形。
结论
最佳信息图以干净、引人注目的风格讲述故事。正如写得很好的 Python 代码几乎不需要注释一样,优秀的信息图也不需要很多标签或注释。
在这个项目中,我们使用 pandas 加载和准备数据,并使用 seaborn 生成一个模仿条形图的散点图。这个图的一个关键特性是使用STIX 字体符号作为矩形标记。
对于具有许多低计数值的数据集,这种散点图方法比标准条形图更具视觉吸引力,因为条形图中的许多条形会很短。此外,用不同的标记表示每个成员比为多个成员显示单一条形图更能“个性化”数据。
谢谢!
感谢阅读,请关注我以获取更多快速成功数据科学项目。
使用 Reflex 在纯 Python 中构建一个类似 ChatGPT 的 Web 应用
使用 OpenAI 的 API 在纯 Python 中构建一个聊天 Web 应用,部署只需一行代码
·
关注 发表在 Towards Data Science ·8 分钟阅读·2023 年 11 月 7 日
–
聊天应用 GIF 由作者提供
在过去的几个月里,我一直在玩各种令人惊叹的新 LLM 聊天机器人,包括 Llama 2、GPT-4、Falcon 40B 和 Claude 2。一个始终困扰我的问题是,我如何构建自己的聊天机器人界面,调用所有这些出色的 LLM 作为 API?
现在有无数的选项可以用来构建美丽的用户界面,但作为一名机器学习工程师,我对 JavaScript 或任何前端语言都没有经验。我在寻找一种只使用我目前知道的语言——Python 来构建我的 Web 应用的方法!
我决定使用一个相对较新的开源框架 Reflex,它允许我完全用 Python 构建后端和前端。
免责声明: 我在 Reflex 担任创始工程师,负责对开源框架做出贡献。
在本教程中,我们将讲解如何从头开始使用纯 Python 构建一个完整的 AI 聊天应用——你还可以在这个 Github 仓库 找到所有代码。
你将学到如何:
-
安装
reflex
并设置你的开发环境。 -
创建组件来定义和设计你的 UI。
-
使用状态为你的应用添加交互性。
-
使用一行命令部署你的应用,与你人分享。
设置你的项目
我们将从创建一个新项目和设置开发环境开始。首先,为你的项目创建一个新目录并进入该目录。
~ $ mkdir chatapp
~ $ cd chatapp
接下来,我们将为我们的项目创建一个虚拟环境。在这个示例中,我们将使用 venv 来创建虚拟环境。
chatapp $ python3 -m venv .venv
$ source .venv/bin/activate
现在,我们将安装 Reflex 并创建一个新项目。这将创建项目目录中的新目录结构。
chatapp $ pip install reflex
chatapp $ reflex init
────────────────────────────────── Initializing chatapp ───────────────────────────────────
Success: Initialized chatapp
chatapp $ ls
assets chatapp rxconfig.py .venv
你可以运行模板应用来确保一切正常。
chatapp $ reflex run
─────────────────────────────────── Starting Reflex App ───────────────────────────────────
Compiling: ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 1/1 0:00:00
─────────────────────────────────────── App Running ───────────────────────────────────────
App running at: http://localhost:3000
你应该能在 localhost:3000
看到你的应用运行。
Reflex 还会启动后台服务器,处理所有状态管理和与前端的通信。你可以通过访问 localhost:8000/ping
来测试后台服务器是否正在运行。
现在我们已经设置好了项目,开始构建我们的应用吧!
基本前端
让我们从定义聊天应用的前端开始。在 Reflex 中,前端可以拆分为独立的、可重用的组件。有关更多信息,请查看 组件文档。
显示问题和答案
我们将修改 chatapp/chatapp.py
文件中的 index
函数,以返回一个显示单个问题和答案的组件。
作者提供的图片(代码如下)
# chatapp.py
import reflex as rx
def index() -> rx.Component:
return rx.container(
rx.box(
"What is Reflex?",
# The user's question is on the right.
text_align="right",
),
rx.box(
"A way to build web apps in pure Python!",
# The answer is on the left.
text_align="left",
),
)
# Add state and page to the app.
app = rx.App()
app.add_page(index)
app.compile()
组件可以相互嵌套以创建复杂的布局。在这里,我们创建了一个父容器,其中包含两个框,用于显示问题和答案。
我们还为组件添加了一些基本的样式。组件接受关键字参数,称为 props,这些参数可以修改组件的外观和功能。我们使用 text_align
属性将文本对齐到左侧和右侧。
重用组件
现在我们有了一个显示单个问题和答案的组件,我们可以重用它来显示多个问题和答案。我们将把该组件移动到一个单独的函数 question_answer
中,并从 index
函数中调用它。
作者提供的图片(代码如下)
def qa(question: str, answer: str) -> rx.Component:
return rx.box(
rx.box(question, text_align="right"),
rx.box(answer, text_align="left"),
margin_y="1em",
)
def chat() -> rx.Component:
qa_pairs = [
(
"What is Reflex?",
"A way to build web apps in pure Python!",
),
(
"What can I make with it?",
"Anything from a simple website to a complex web app!",
),
]
return rx.box(
*[
qa(question, answer)
for question, answer in qa_pairs
]
)
def index() -> rx.Component:
return rx.container(chat())
聊天输入
现在我们希望用户能够输入一个问题。为此,我们将使用input组件让用户添加文本,并使用button组件来提交问题。
作者提供的图片(下面的代码)
def action_bar() -> rx.Component:
return rx.hstack(
rx.input(placeholder="Ask a question"),
rx.button("Ask"),
)
def index() -> rx.Component:
return rx.container(
chat(),
action_bar(),
)
样式
让我们给应用添加一些样式。有关样式的更多信息可以在styling docs中找到。为了保持代码的整洁,我们将样式移动到一个单独的文件chatapp/style.py
中。
# style.py
# Common styles for questions and answers.
shadow = "rgba(0, 0, 0, 0.15) 0px 2px 8px"
chat_margin = "20%"
message_style = dict(
padding="1em",
border_radius="5px",
margin_y="0.5em",
box_shadow=shadow,
max_width="30em",
display="inline-block",
)
# Set specific styles for questions and answers.
question_style = message_style | dict(
bg="#F5EFFE", margin_left=chat_margin
)
answer_style = message_style | dict(
bg="#DEEAFD", margin_right=chat_margin
)
# Styles for the action bar.
input_style = dict(
border_width="1px", padding="1em", box_shadow=shadow
)
button_style = dict(bg="#CEFFEE", box_shadow=shadow)
我们将导入chatapp.py
中的样式并在组件中使用它们。此时,应用应该如下所示:
作者提供的图片
# chatapp.py
import reflex as rx
from chatapp import style
def qa(question: str, answer: str) -> rx.Component:
return rx.box(
rx.box(
rx.text(question, style=style.question_style),
text_align="right",
),
rx.box(
rx.text(answer, style=style.answer_style),
text_align="left",
),
margin_y="1em",
)
def chat() -> rx.Component:
qa_pairs = [
(
"What is Reflex?",
"A way to build web apps in pure Python!",
),
(
"What can I make with it?",
"Anything from a simple website to a complex web app!",
),
]
return rx.box(
*[
qa(question, answer)
for question, answer in qa_pairs
]
)
def action_bar() -> rx.Component:
return rx.hstack(
rx.input(
placeholder="Ask a question",
style=style.input_style,
),
rx.button("Ask", style=style.button_style),
)
def index() -> rx.Component:
return rx.container(
chat(),
action_bar(),
)
app = rx.App()
app.add_page(index)
app.compile()
应用看起来不错,但还不是很有用!现在让我们添加一些功能。
状态
现在让我们通过添加状态来使聊天应用变得互动。状态是我们定义应用中所有可以变化的变量以及所有可以修改这些变量的函数的地方。你可以在state docs中了解更多关于状态的信息。
定义状态
我们将在chatapp
目录中创建一个名为state.py
的新文件。我们的状态将跟踪当前提出的问题和聊天记录。我们还将定义一个事件处理程序answer
,它将处理当前的问题并将答案添加到聊天记录中。
# state.py
import reflex as rx
class State(rx.State):
# The current question being asked.
question: str
# Keep track of the chat history as a list of (question, answer) tuples.
chat_history: list[tuple[str, str]]
def answer(self):
# Our chatbot is not very smart right now...
answer = "I don't know!"
self.chat_history.append((self.question, answer))
将状态绑定到组件
现在我们可以在chatapp.py
中导入状态,并在前端组件中引用它。我们将修改chat
组件,以使用状态代替当前固定的问题和答案。
作者提供的图片
# chatapp.py
from chatapp.state import State
...
def chat() -> rx.Component:
return rx.box(
rx.foreach(
State.chat_history,
lambda messages: qa(messages[0], messages[1]),
)
)
...
def action_bar() -> rx.Component:
return rx.hstack(
rx.input(
placeholder="Ask a question",
on_change=State.set_question,
style=style.input_style,
),
rx.button(
"Ask",
on_click=State.answer,
style=style.button_style,
),
)
普通的 Python for
循环无法用于遍历状态变量,因为这些值可能会变化且在编译时未知。相反,我们使用foreach组件来遍历聊天记录。
我们还将输入框的on_change
事件绑定到set_question
事件处理程序,这将更新question
状态变量,而用户在输入框中输入时。我们将按钮的on_click
事件绑定到answer
事件处理程序,这将处理问题并将答案添加到聊天记录中。set_question
事件处理程序是一个内置的隐式定义事件处理程序。每个基础变量都有一个。更多信息请查看events docs中的 Setters 部分。
清空输入框
目前,用户点击按钮后输入框不会清空。我们可以通过将输入框的值绑定到question
,设置为value=State.question
,并在运行answer
事件处理程序时将其清空,使用self.question = ''
来解决这个问题。
# chatapp.py
def action_bar() -> rx.Component:
return rx.hstack(
rx.input(
value=State.question,
placeholder="Ask a question",
on_change=State.set_question,
style=style.input_style,
),
rx.button(
"Ask",
on_click=State.answer,
style=style.button_style,
),
)
# state.py
def answer(self):
# Our chatbot is not very smart right now...
answer = "I don't know!"
self.chat_history.append((self.question, answer))
self.question = ""
流式文本
通常,状态更新会在事件处理程序返回时发送到前端。然而,我们希望在生成的过程中流式传输来自聊天机器人的文本。我们可以通过从事件处理程序中生成来实现。有关更多信息,请参见事件生成文档。
# state.py
import asyncio
...
async def answer(self):
# Our chatbot is not very smart right now...
answer = "I don't know!"
self.chat_history.append((self.question, ""))
# Clear the question input.
self.question = ""
# Yield here to clear the frontend input before continuing.
yield
for i in range(len(answer)):
# Pause to show the streaming effect.
await asyncio.sleep(0.1)
# Add one letter at a time to the output.
self.chat_history[-1] = (
self.chat_history[-1][0],
answer[: i + 1],
)
yield
使用 API
我们将使用 OpenAI 的 API 为我们的聊天机器人提供一些智能。我们需要修改事件处理程序以向 API 发送请求。
# state.py
import os
import openai
openai.api_key = os.environ["OPENAI_API_KEY"]
...
def answer(self):
# Our chatbot has some brains now!
session = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": self.question}
],
stop=None,
temperature=0.7,
stream=True,
)
# Add to the answer as the chatbot responds.
answer = ""
self.chat_history.append((self.question, answer))
# Clear the question input.
self.question = ""
# Yield here to clear the frontend input before continuing.
yield
for item in session:
if hasattr(item.choices[0].delta, "content"):
answer += item.choices[0].delta.content
self.chat_history[-1] = (
self.chat_history[-1][0],
answer,
)
yield
最后,我们有了我们的 AI 聊天机器人!
结论
按照这个教程,我们已经成功创建了使用 OpenAI API 密钥的聊天应用,完全用 Python 编写。
要现在运行这个应用,我们可以运行简单的命令:
$ reflex run
要部署它,以便我们可以与其他用户分享,我们可以运行以下命令:
$ reflex deploy
我希望这个教程能激发你们构建自己的 LLM 基础应用。我迫不及待想看到你们最终开发出什么,所以请在社交媒体或评论中与我联系。
如果你有问题,请在下方评论或者通过 Twitter 上的@tgotsman12或LinkedIn给我发消息。分享你的应用创作到社交媒体并标记我,我很乐意提供反馈或帮助转发!
使用 Numpy 从零开始构建卷积神经网络
原文:
towardsdatascience.com/build-a-convolutional-neural-network-from-scratch-using-numpy-139cbbf3c45e
通过从零开始构建 CNN 来掌握计算机视觉
·发表于Towards Data Science·8 分钟阅读·2023 年 11 月 23 日
–
这些彩色窗户让我想起了 CNN 的层及其过滤器。图片来源:unsplash.com。
由于计算机视觉应用现在无处不在,每个数据科学从业者都必须了解其工作原理和熟悉它们。
在这篇文章中,我在不依赖流行的现代深度学习库如 Tensorflow、Pytorch 和 Keras 的情况下构建了深度神经网络。我随后用它对手写数字进行分类。尽管取得的结果未达到最先进的水平,但仍然令人满意。现在,我想在使用 Python 库Numpy的情况下迈出进一步一步,开发一个卷积神经网络(CNN)。
Python 深度学习库,如上所述,是极其强大的工具。然而,作为一个缺点,它们使数据科学从业者无法理解神经网络的低级工作原理。这一点在 CNN 中尤其明显,因为它们的过程比经典的全连接网络更不直观。解决这一问题的唯一方法就是亲自动手实现 CNN:这就是这项任务的动机。
这篇文章旨在作为一个实用的、动手指南,而不是一个全面的 CNN 工作原理指南。因此,理论部分简明扼要,主要服务于实际部分的理解。为此,你会在本文末尾找到一份详尽的资源列表。我热情邀请你去查看它们!
卷积神经网络
卷积神经网络使用一种特定的架构和操作,使其非常适合与图像相关的任务,例如图像分类、目标定位、图像分割等。它们的设计大致反映了人类视觉皮层,其中每个生物神经元仅响应视觉场的一个小部分。此外,更高级的神经元对低级神经元的输出做出反应。
虽然经典的全连接网络可以处理与图像相关的任务,但当应用于中等或大型图像时,由于所需的参数数量众多,它们的效果会显著下降。例如,一个 200x200 像素的图像包含 40,000 个像素,如果网络的第一层有 1,000 个单元,仅这一层就需要 4000 万个权重。这个挑战通过 CNN 得到了极大的缓解,因为它们实现了部分连接层和权重共享。
卷积神经网络的主要组件是:
-
卷积层
-
池化层
卷积层
卷积层由一组滤波器组成,也称为内核。当应用于层的输入时,这些滤波器以特定方式修改原始图像。
一个滤波器可以描述为一个矩阵,其元素值定义了对原始图像应用的修改类型。例如,如下的 3x3 内核突出了图像的垂直边缘:
这个内核则强调水平边缘:
来源: 维基百科.
需要注意的是,这些内核的元素值不是手动选择的,而是网络在训练过程中学习的参数。
卷积的主要功能是隔离并突出图像中存在的不同特征。稍后,密集层将使用这些特征。
池化层
池化层比卷积层更简单。它们的目的是最小化网络的计算负担和内存使用。它们通过缩小输入图像的尺寸来实现这一目标。降低维度会减少 CNN 需要学习的参数数量。
池化层也使用一个通常为 2x2 维度的内核,将输入图像的一部分汇聚成一个值。例如,一个 2x2 的最大池化内核从输入图像中提取 4 个像素,并只输出值最大的像素。
Python 实现
你可以在我的 GitHub 仓库中找到本节中展示的所有代码。
[## GitHub - andreoniriccardo/CNN-from-scratch: 从零开始构建卷积神经网络
从零开始构建卷积神经网络。通过创建一个…来贡献到 andreoniriccardo/CNN-from-scratch 的开发。
github.com](https://github.com/andreoniriccardo/CNN-from-scratch?source=post_page-----139cbbf3c45e--------------------------------)
该实现的概念是创建Python 类,代表卷积层和最大池化层。此外,由于此 CNN 将应用于著名的开源 MNIST 数据集,我还创建了一个用于 Softmax 层的特定类。
在每个类中,我定义了执行前向传播和反向传播步骤的方法。
作为最后一步,将这些层附加到一个列表中以构建最终的卷积神经网络。
卷积层实现
定义卷积层的代码如下:
class ConvolutionLayer:
def __init__(self, kernel_num, kernel_size):
self.kernel_num = kernel_num
self.kernel_size = kernel_size
self.kernels = np.random.randn(kernel_num, kernel_size, kernel_size) / (kernel_size**2)
def patches_generator(self, image):
image_h, image_w = image.shape
self.image = image
for h in range(image_h-self.kernel_size+1):
for w in range(image_w-self.kernel_size+1):
patch = image[h:(h+self.kernel_size), w:(w+self.kernel_size)]
yield patch, h, w
def forward_prop(self, image):
image_h, image_w = image.shape
convolution_output = np.zeros((image_h-self.kernel_size+1, image_w-self.kernel_size+1, self.kernel_num))
for patch, h, w in self.patches_generator(image):
convolution_output[h,w] = np.sum(patch*self.kernels, axis=(1,2))
return convolution_output
def back_prop(self, dE_dY, alpha):
dE_dk = np.zeros(self.kernels.shape)
for patch, h, w in self.patches_generator(self.image):
for f in range(self.kernel_num):
dE_dk[f] += patch * dE_dY[h, w, f]
self.kernels -= alpha*dE_dk
return dE_dk
**ConvolutionLayer**
类的构造函数接收卷积层的内核数量及其大小作为输入。我假设只使用大小为**kernel_size**
乘**kernel_size**
的方形内核**。
然后,我生成形状为**(kernel_num, kernel_size, kernel_size)**
的随机滤波器,并通过将每个元素除以平方的内核大小进行归一化。
**patches_generator()**
方法是一个生成器。它生成要进行每个卷积步骤的图像部分。
**forward_prop()**
方法对上述方法生成的每个补丁执行卷积操作。
最后,**back_prop()**
方法负责计算损失函数相对于每层权重的梯度。它还相应地更新权重值。请注意,这里提到的损失函数不是网络的全局损失,而是由最大池化层传递给前一个卷积层的损失函数。
为了演示此类的实际效果,我创建了一个 **ConvolutionLayer**
实例,具有 32 个滤波器,每个滤波器的大小为 3x3。然后,我在一张图像上应用前向传播方法,得到由 32 张稍小的图像组成的输出。
初始输入图像的大小为 28x28 像素,如下图所示:
图片来源于作者。
一旦我应用了卷积层的**forward_prop()**
方法,我得到 32 张 26x26 像素的图像。其中之一如下:
图片由作者提供。
正如你所见,图像的尺寸被缩小了,手写数字的清晰度更差了。需要注意的是,这一操作是由一个包含随机值的滤波器执行的,因此,它并不准确代表经过训练的 CNN 实际执行的步骤。尽管如此,你可以理解这些卷积如何产生较小的图像,在这些图像中对象的特征被分离出来。
Max Pooling 层实现
我使用 Numpy 定义了 Max Pooling 层类,如下所示:
class MaxPoolingLayer:
def __init__(self, kernel_size):
self.kernel_size = kernel_size
def patches_generator(self, image):
output_h = image.shape[0] // self.kernel_size
output_w = image.shape[1] // self.kernel_size
self.image = image
for h in range(output_h):
for w in range(output_w):
patch = image[(h*self.kernel_size):(h*self.kernel_size+self.kernel_size), (w*self.kernel_size):(w*self.kernel_size+self.kernel_size)]
yield patch, h, w
def forward_prop(self, image):
image_h, image_w, num_kernels = image.shape
max_pooling_output = np.zeros((image_h//self.kernel_size, image_w//self.kernel_size, num_kernels))
for patch, h, w in self.patches_generator(image):
max_pooling_output[h,w] = np.amax(patch, axis=(0,1))
return max_pooling_output
def back_prop(self, dE_dY):
dE_dk = np.zeros(self.image.shape)
for patch,h,w in self.patches_generator(self.image):
image_h, image_w, num_kernels = patch.shape
max_val = np.amax(patch, axis=(0,1))
for idx_h in range(image_h):
for idx_w in range(image_w):
for idx_k in range(num_kernels):
if patch[idx_h,idx_w,idx_k] == max_val[idx_k]:
dE_dk[h*self.kernel_size+idx_h, w*self.kernel_size+idx_w, idx_k] = dE_dY[h,w,idx_k]
return dE_dk
构造方法仅分配内核大小值。接下来的方法与卷积层中定义的方法类似,主要的区别在于**back_prop()**
方法不会更新任何权重值。实际上,池化层不依赖于权重来执行聚合操作。
Softmax 层实现
最后,我定义了Softmax 层。它的目的是展平从最终 Max Pooling 层获得的输出体积。Softmax 层输出 10 个值,这些值可以被解释为图像对应于 0 到 9 数字的概率。
实现具有与上述相同的结构:
class SoftmaxLayer:
def __init__(self, input_units, output_units):
self.weight = np.random.randn(input_units, output_units)/input_units
self.bias = np.zeros(output_units)
def forward_prop(self, image):
self.original_shape = image.shape
image_flattened = image.flatten()
self.flattened_input = image_flattened
first_output = np.dot(image_flattened, self.weight) + self.bias
self.output = first_output
softmax_output = np.exp(first_output) / np.sum(np.exp(first_output), axis=0)
return softmax_output
def back_prop(self, dE_dY, alpha):
for i, gradient in enumerate(dE_dY):
if gradient == 0:
continue
transformation_eq = np.exp(self.output)
S_total = np.sum(transformation_eq)
dY_dZ = -transformation_eq[i]*transformation_eq / (S_total**2)
dY_dZ[i] = transformation_eq[i]*(S_total - transformation_eq[i]) / (S_total**2)
dZ_dw = self.flattened_input
dZ_db = 1
dZ_dX = self.weight
dE_dZ = gradient * dY_dZ
dE_dw = dZ_dw[np.newaxis].T @ dE_dZ[np.newaxis]
dE_db = dE_dZ * dZ_db
dE_dX = dZ_dX @ dE_dZ
self.weight -= alpha*dE_dw
self.bias -= alpha*dE_db
return dE_dX.reshape(self.original_shape)
图片由作者提供。
结论
在这篇文章中,我们看到对基本 CNN 架构元素如卷积层和池化层的理论介绍。我相信,逐步的 Python 实现将为你提供实际理解这些理论概念如何转化为代码的途径。
我邀请你克隆包含代码的GitHub 仓库并尝试**main.py**
脚本。当然,这个网络并没有达到最先进的性能,因为它不是为了这个目标而构建的,但在经过几个训练周期后,仍然达到了96%的准确率。
最后,为了扩展你对 CNN 和计算机视觉的知识,我建议你查看下面列出的一些资源。
如果你喜欢这个故事,请考虑关注我,以便了解我即将推出的项目和文章!
参考文献
-
“ImageNet 分类与深度卷积神经网络”由 Alex Krizhevsky, Ilya Sutskever, 和 Geoffrey Hinton
-
“《用于大规模图像识别的深度卷积网络》” 作者:Karen Simonyan 和 Andrew Zisserman (VGGNet)
-
“《Python 中的卷积神经网络:掌握数据科学和机器学习,现代深度学习方法,使用 Python、Theano 和 TensorFlow》” 作者:Jason Brownlee
如何为任何团队规模构建数据科学战略
创建一个快速变动且对变化具有弹性的文化和实践
·发表于 走向数据科学 ·20 分钟阅读·2023 年 9 月 11 日
–
照片由 Maarten van den Heuvel 拍摄,来源于 Unsplash
创建一个快速变动且对变化具有弹性的文化和实践
如果你是一个数据科学领导者,被要求在“建立我们的数据科学战略”时拥有很大的自由和很少的方向,这篇文章会对你有所帮助。我们将讨论:
-
我们所说的战略是什么:仅仅是一个计划?一个路线图?还是更多或更少的东西?在本节中,我们将具体化并采用一个工作定义,以了解我们在构建战略时实际构建了什么。
-
这个概念如何在实际的组织背景下应用于数据科学团队?在这里,我们将探讨我们对战略的概念如何适用于数据科学,并具体说明我们的战略应用于哪些方面。
-
如何实际制定该战略。
在整个过程中,我们将大量借鉴研发战略的方法,因为它与数据科学面临的关键挑战类似:创新的使命,以及追求发现所带来的不确定性。当我们结束时,你将获得一个明确的战略定义,以及一个适用于任何规模组织的有用的制定过程。
什么是战略?
如果你像我一样,没有高级 MBA 学位,也从未参加过商业战略研讨会,你可能会对有人要求你制定“数据科学战略”时究竟想要什么感到困惑。你可能会发现最初的搜索并没有多大帮助。像三 C 模型(顾客、竞争者、公司)这样的经典强大框架,在决定公司应该竞争的领域时非常合理。将其应用于一个职能或团队时,你会发现自己感觉在拉伸这些概念的承受范围。
如果你像我一样,阅读像战略领主和麦肯锡方法这样的书籍会让你陷入一个相当深入的阅读漩涡中。(附属链接。)前者是一本令人愉快的商业历史著作,后者是从成功的咨询公司经验中提炼出的有用技巧合集。两者都没有提供快速的答案。阅读《战略领主》的一个非常有益的副作用是了解到数据科学家们并不孤单:“[我]很容易把战略与战略规划混为一谈,但这也是危险的。[…] 即使今天,拥有计划的公司仍然远多于拥有战略的公司。刮掉大多数计划,你会发现某种版本的‘我们将继续做我们一直在做的事,但明年我们会做得更多和/或更好’。这种定义混淆在我的经验中也有所体现,几次对战略的请求实际上简化为‘你接下来几个月的计划是什么?’”
我们将在本文其余部分采用的一个非常有用的战略定义,来源于 Gary Pisano 的这篇关于研发战略的工作论文:“战略不过是一种对行为模式的承诺,旨在帮助赢得竞争。”这个定义的美妙之处在于它可以适用于组织的任何层级和目的。所有类型和规模的团队都参与组织的竞争努力,所有团队都可以定义并声明他们用来集中这些努力的行为模式。
战略不过是一种对行为模式的承诺,旨在帮助赢得竞争。
—Gary Pisano
Pisano 提出了一个好的战略的三个要求:一致性、一致性和对齐。战略应该帮助我们做出一致的决策,这些决策累积起来有助于实现预期目标;应该帮助组织的各个角落将其分散的战术决策协调一致;并且应该使地方行动与更大的集体努力保持一致。
最终,它们都建立在核心假设上,即关于在竞争中提供优势的赌注。皮萨诺的有用例子是苹果,其战略“开发易于使用、外观美观的产品,并与消费者数字世界中的更广泛设备系统无缝集成”建立在一个核心假设上:“客户将愿意为具有这些属性的产品支付显著更高的价格。”
本质上,根据这个定义,所有策略都是包装决策逻辑的赌注:它们为所有各方提供了确定哪些行动有助于集体努力的方法。
我们将采用这个策略定义,并努力定义我们自己的核心战略假设,关于数据科学如何为我们的组织增值,以及我们将在追求该价值的过程中坚持的模式。此外,我们将假设我们的母公司已经制定了自己的战略,这一输入在我们应用第三个对齐测试时将是至关重要的。在确定了我们最终战略的形式之后,我们将把注意力转向限定其范围。
我们所说的数据科学是什么意思,这个战略概念如何适用?
为了提醒我的朋友们我有多么有趣,我给几个人发了相同的短信,“你听到‘数据科学策略’时会想到什么?”答案从对数据基础设施和 MLOps 的深思熟虑,到对问题模糊性的健康反应(我觉得被看到了),再到丰富多彩的“胡说八道”和“我的理想工作”。
尽管样本较小,但来自这一群体的多样化回应——包括初创公司和大型公司的资深产品经理、数据科学负责人和顾问——表明了这个术语的定义可能有多么模糊。更糟糕的是,数据科学家还面临第二层次的困惑:所谓的“数据科学”在实践中往往取决于公司招聘所需的技能,并用流行的标题装点。
为了在我们的分析中固定这些自由度,我们将首先采用一个共同的数据科学定义来进行本文的其余部分:致力于通过建模组织的可用数据来创造价值和竞争优势的职能。这可以采取几种典型形式:
-
构建优化客户决策的机器学习模型以用于生产
-
建立帮助各级员工完成工作的模型,可能应用于客户互动的人工环节应用。
-
建立可解释的模型以辅助商业决策
注意,我们排除了 BI 和分析,仅仅是为了聚焦,而不是因为它们不如建模工作有价值。你的分析部门和数据科学部门应当顺利合作。(我曾在这里写过这方面的内容。)
一些人,比如我的朋友和谷歌产品经理卡罗尔·斯科达斯·沃尔波特,会建议数据科学战略包括“如何使数据和基础设施处于足够好的状态,以进行分析或机器学习。我会说这是如何使团队完成所有工作的。”我们将有意排除这些更广泛的数据战略项目(抱歉,卡罗尔)。不过,我们将讨论如何应对数据和基础设施限制,以及如何发展数据科学战略以积极指导更广泛的数据战略。
现在我们有了界限:我们正在构建一套核心战略假设,关于机器学习和/或人工智能如何为拥有自己定义战略或目标的组织增加最大价值,以及团队在追求这一价值时将承诺的一系列模式。我们该如何开始?
建立我们的战略核心假设:从一个赢得 AI 的心态开始
经验丰富的机器学习产品经理、工程师和数据科学家常常会提到,机器学习产品与传统软件不同。一个组织必须考虑模型错误、数据漂移、模型监控和重新调整的风险——这也是现代 MLOps 的出现原因。而且,很容易在工程中犯错,使 ML 应用陷入技术债务的沼泽中。(请参阅“机器学习:技术债务的高利贷”以获取有关此主题的精彩阅读。)那么,面对所有这些成本,我们为什么要这么做?
从根本上说,我们考虑 AI 解决方案是因为复杂的模型已经证明能够检测到有价值的模式。这些模式可以是从暗示新分段的客户偏好的聚类到神经网络发现以优化预测的潜在表示。任何给定的机器学习构建都依赖于一个假设或预期,即模型可以检测到可以改进过程、发现可操作的发现或改进有价值的预测的模式。
在定义任何规模的数据科学团队的核心战略假设时,我们可以从这个麦肯锡的示例描述开始,描述 AI 驱动的公司如何以不同的方式思考。来自“赢得 AI 是一种心态”:
如果我们选择了正确的用例,并以正确的方式进行,我们将不断了解我们的客户及其需求,并不断改进我们服务他们的方式。
这是构建数据科学战略时非常有用的视角:它让我们专注于最大化学习,我们需要做的只是确定我们组织对“正确”的定义。但对我们来说,“正确”的用例是什么?
在这里,皮萨诺再次提供了帮助,定义了 R&D 战略的四个要素,这些要素很好地适用于数据科学:
-
架构:我们数据科学职能的组织结构(集中式、分布式)和地理结构。
-
过程:管理我们工作的正式和非正式方式。
-
人员:从我们希望吸引的技能组合到我们的人才价值主张的一切。
-
组合:我们如何在项目类型之间分配资源,以及“排序、优先级和选择项目的标准”。
我们将从最后一个概念开始,重点定义我们组织的理想项目组合,即我们可以说服自己能带来最大价值的组合。鉴于组织之间的巨大差异,我们将从每个组织面临的一个挑战开始:风险。
定义你的目标组合:根据你的战略确定风险水平和管理。
建模工作具有不确定的结果。“机器学习可以做得更好”是我们经常根据历史和直觉提出的论点,且这通常被证明是正确的。但我们从一开始就不知道它的效果如何,直到我们通过实际构建证明机器学习能解决问题。了解任何给定用例的答案可能需要不同程度的努力,从而产生不同程度的成本。这种答案的不确定性也可能不同,取决于我们的模型被应用的广泛程度以及我们对数据的理解程度。
一位朋友和医疗分析产品负责人,约翰·梅纳德将风险定义为数据科学战略中的一个明确部分:“你如何维护一个小规模和大规模的投注管道,同时保持健康的期望?当数据无法支持项目时,你的策略是什么?如果项目未能满足要求,你会如何调整交付内容?”
对于组织来说,明确和具体地了解他们能够承担的资源水平及其时间长度是明智的。以下是对任何个人建模工作提出的一些有用问题:
-
成功的估计可能性:这个模型用例成功的概率是多少?
-
预期回报范围:如果成功,这个项目是否能在一个可以大规模带来巨大回报的过程上带来微小的改进?一个突破是否能让你与竞争对手区分开来?
-
发现失败的预期时间:需要多长时间才能了解一个项目的假设价值主张是否会实现?在了解到这个项目不会成功之前,你能花费的最少资源是多少?
希望这些原则是简单明了的,并且都是共识中的好事。理想的项目可能会成功,带来巨大的投资回报,如果失败,则应尽早失败。这种完美的三位一体往往难以实现。关键在于做出适合你组织的权衡。
一个早期阶段的初创公司,专注于利用人工智能颠覆特定领域,可能会有投资者、领导层和员工接受公司作为对特定方法的单一大型投资。或者,它可能更倾向于那些能够快速进入生产并允许快速调整的小型项目。相反,如果我们在一家大型、成熟的公司和监管严格的行业中,并且利益相关者对机器学习持怀疑态度,我们可能会选择将投资组合偏向于低工作量的项目,这些项目提供渐进的价值并快速失败。这可以帮助建立初步信任,使利益相关者适应数据科学项目固有的不确定性,并使团队围绕更雄心勃勃的项目达成一致。成功的早期小型项目还可以加强对同一问题领域的更大项目的支持。
以下是如何定义目标投资组合的一些示例,包括项目范围、持续时间和预期回报:
-
“由于我们在集体数据科学旅程中仍处于早期阶段,我们专注于小型、低工作量和快速失败的用例,这样可以在不冒大量人员时间风险的情况下发现机会。”
-
“我们已经确定了一个包含三个大型机器学习项目的投资组合,每一个项目都有可能释放巨大的价值。”
-
“我们的目标是平衡小型、中型和大型项目,并与相应的回报水平相匹配。这使我们能够在追求具有颠覆性潜力的同时,频繁获得胜利。”
作为应用于我们完整投资组合的最终原则,目标是一个具有非相关成功的项目集合。也就是说,我们希望看到我们的投资组合,并感知项目将独立成功或失败。如果多个项目依赖于共同的假设,如果我们感到它们如此紧密相关以至于它们会一起成功或失败,那么我们应该重新考虑选择。
当我们完成以下任务时,我们就完成了这个阶段:
-
调查了我们的数据科学和机器学习机会
-
按投资、回报和成功可能性进行了绘制
-
选择了一个与我们的目标和风险承受能力一致的粗略优先级列表
现在我们已经确定了目标投资组合,我们将转向确保我们的流程使我们能够快速识别、范围定义和交付有价值的项目。
将你的投资组合集中在你的团队独特的解决方案上
建设还是购买的问题是一个长期存在的问题,并且通常涉及复杂的组织动态。寻找 AI 解决方案的供应商和初创公司不乏其人。许多是“江湖术士”;许多是有效的。许多内部技术和数据科学团队将前者视为笑话,将后者视为竞争对手,并且将区分两者所花费的时间视为巨大的时间浪费。这是有道理的,因为检查供应商的时间并不能提升建模者的技能,如果组织不奖励他们的努力,那么这就是数据科学家付出的成本,而没有职业上的回报。这种人际关系复杂性加剧了本已复杂的商业案例:所有典型的软件解决方案关注点仍然存在。你仍然需要担心供应商锁定和云集成等问题。然而,我们都应该愿意购买那些提供更高投资回报率的供应商产品,如果你考虑到内部团队相对于现成解决方案的独特优势,你可以排除干扰。
特别是,你的内部团队通常可以访问你组织的大部分(可能是全部)专有数据。这意味着内部团队可能会比单一用途的供应商解决方案更深入地理解这些数据,并更容易将其与其他来源进行丰富。只要有足够的时间和计算资源,一个有能力的内部团队可能会比单一用途的供应商解决方案更胜一筹。(这里面有个 PAC 理论的笑话。)但这值得吗?
标准的投资回报率和替代方案分析在这里是关键,重点是你内部市场的时间。假设我们正在优化一个电子商务网站上的广告投放。我们已经将供应商名单缩小到一个领先者,该供应商使用的是一种多臂赌博方法,这是在撰写本文时领先的市场优化供应商中常见的方法。我们估计供应商集成的时间为一个月。或者,我们可以建立自己的 MAB,并估计需要六个月。我们是否期望自己构建的 MAB 能够超越供应商的 MAB,并且值得为了这个延迟而付出?
这要看情况。使用汤普森采样进行多臂赌博问题(MAB)可以为预期的遗憾提供对数界限,这是一种行话,意味着它在探索选项时不会留下一大堆未利用的价值。无论是由你的内部团队还是供应商实施,这一说法仍然可以证明是正确的。相反,你的内部团队更接近你的数据,将这种用例带到内部相当于一个赌注,即你会在数据中找到足够丰富的信号来超越供应商产品。也许你的团队可以注入现成解决方案所没有的领域知识,从而提供有价值的优势。最后,考虑你的内部团队的机会成本:是否有其他高价值的项目他们可以从事?如果有,一个选项是测试供应商,处理其他项目,并在获得可衡量的供应商结果后重新评估。
当我们完成以下任务时,这一阶段就结束了:
-
回顾之前步骤中的机会,并对每个机会回答,“我们能买到这个吗?”
-
对于每个可购买的解决方案,回答我们是否有独特的已知或假设的内部优势。
-
对于每个需要做出真实权衡的领域,进行权衡分析。
确定了我们内部团队的战略竞争优势后,我们现在将考虑我们的内部流程、工具和数据能力。
在你的知识工厂工具和数据供应链周围建立流程。
我与许多经验丰富的数据科学家讨论过时间投入的问题,每个人都提到发现、处理、清理和移动(到适当的计算环境)数据占据了他们大部分工作时间。正如另一组麦肯锡作者在 AutoML 和 AI 人才战略上写道:“许多组织发现,数据科学家花费 60%到 80%的时间来准备建模数据。一旦初步模型构建完成,根据一些分析,只有 4%的时间用于测试和调优代码。”这并不是大多数人进入这一领域的原因。在我们大多数人的心中,这是一种为了构建有影响力的模型的乐趣而付出的代价。因此,我们常常谈论数据科学家成功所需的“基础”。根据我的经验,这种框架很快会阻碍我们,我将挑战我们将自己视为一个模型工厂,受到工具和复杂且常常有问题的数据供应链的限制。
坦白说:当讨论平台时,我从未相信这些“基础”论点。
“数据和机器学习平台是成功的机器学习所依赖的基础,”这是无数幻灯片和白皮书中的明确声明。 “如果没有强大的基础,”某些顾问父爱般地总结道,“一切都会崩溃。”
然而,这里有个问题:很少有事情会因为没有机器学习而“崩溃”。如果你的房子建在不好的基础上,车库可能会塌陷在自己身上,甚至你也会受害。如果你在没有完善的数据和机器学习平台的情况下开始机器学习项目,那么你的模型构建将会……需要更长时间。而且没有那种新奇的机器学习模型,你的业务很可能会继续以原有方式运行,尽管缺少了机器学习原本旨在提供的某种竞争优势。但在平庸中持续并非末日。
这就是这个陈词滥调让我感到困惑的地方。它试图吓唬高管们投入平台建设——值得强调的是有价值的建设——好像没有它们世界就会末日,但其实不会。我们大声喊着天空要塌下来,然后当一个利益相关者遇到他们习惯的老雨时,我们失去了信誉。
尽管如此,我敢打赌,拥有强大机器学习能力的公司将超越那些没有这种能力的竞争者——我明白作为建模负责人我的职业生涯正是这样一种赌注——而现代数据和 MLOps 能力可以大大缩短 AI 能力的市场时间。请参见麦肯锡论文中的摘录 “像科技本地人一样扩展 AI: CEO 的角色”,强调由我加的:
我们经常听到高管们说,将 AI 解决方案 从构想到实施需要九个月到一年以上*,使得很难跟上市场动态的变化。即便经过多年的投资,领导者们经常告诉我们他们的组织并没有加快速度。相比之下,* 采用 MLOps 的公司可以在仅仅两到 12 周内将构思转变为实际解决方案 而不会增加人员或技术债务,减少了实现价值的时间,并释放团队以更快地扩展 AI。
你的数据科学战略需要考虑你的组织和工具约束,并采用在这些约束下能产生可操作模型或知识单元的模式。也就是说,建模项目应该总是包含:
-
清晰地了解最小可行建模数据。你的数据科学团队应该知道源数据的位置,并且对数据需要如何转化有一个大致的了解。
-
实现价值的直接和现实路径。你将如何让一个具有足够性能的模型投入使用,或者以其他方式应用模型结果?
早期阶段的公司或团队如果在架构和工具方面拥有完全的自由,将有利于采用现代 MLOps 实践,这将使得快速原型设计、部署和监控模型以评估其在实际世界中影响变得更加容易。与传统遗留技术并行工作或在其中工作的团队可能会发现这些技术并未考虑 ML 集成,并且部署是一个庞大的、沉重的过程。受严格监管行业中的公司将发现许多应用程序需要高度的可解释性和风险控制。
这些挑战没有不可克服的。我们只需在时间表的影响上保持原则性和聪明,并将其纳入我们的决策中。
当我们完成这个阶段时,我们会有:
-
调查我们计划中的用例,以确定每个用例获取数据的路径以便开始。
-
确定每个用例的实现价值路径,如果它能够成功的话。
-
将这一点纳入我们的预期投资中,并从第一步开始进行调整。
-
根据我们发现的任何变化调整了我们的优先级。
在完善了我们关于数据科学部署的想法后,我们将考虑工作模型以确保一致性。
架构与组织:构建一个能够持续成功的组织结构。
Pisano 将架构定义为“围绕研发在组织和地理上的结构所做的一系列决策。” 设计这点包括对如何将我们的数据科学家与业务部门整合做出深思熟虑的决定。他们是完全集中并有正式的接收流程?向不同的业务单位汇报?集中并嵌入其中?汇报结构和决策权可能不在你的掌控之下,特别是当你被要求为有明确汇报线的部门建立战略时。但如果这些点正在讨论中,这里有一些最大化数据科学成果价值的考虑因素。
你的数据科学家会得到良好的支持并被适当衡量吗? 考虑一下初级数据科学人才的来源。数据科学家来自各种定量背景,通常具有理论和实践技能的混合。一个典型的硕士毕业生在这些形成阶段中建立了技能和理解,并向领域专家展示了这些理解。这通常不包括大量的与非专家沟通技术发现的培训。
与他们在商业环境中的经历相比,他们可能对领域的了解较少,并且是少数拥有方法知识的人之一。他们将被要求应用少数人理解的技术。他们的项目必然比标准软件构建包含更多的不确定性。他们的成功将依赖于更多因素,许多因素超出数据科学家的控制范围,他们在阐述要求以最大化成功机会方面经验有限。将所有这些因素综合考虑,我们开始看到一种被投入深水区的情况。
这可能导致其他职能领导在首次领导数据科学团队时面临挑战。这一教训来自麦肯锡的“为现代时代建立研发战略”,也适用于我们的领域:
组织倾向于青睐那些有近期回报的“安全”项目——比如那些源于客户需求的项目——这些项目在许多情况下只是维持现有市场份额。例如,一家消费品公司将研发预算划分给其业务单位,业务单位的领导者则用这些资金来达成短期目标,而不是公司的长期差异化和增长目标。
在我们的领域,这通常表现为初级数据科学家被他们的非技术主管要求编写能够回答当天问题的任何 SQL 查询。这通常是有帮助的,但通常不是企业通过招聘精明建模师来驱动的价值。
当你有曾经管理过数据科学(DS)或机器学习(ML)项目的领导时,这个问题会更容易解决。无论职能如何,成功的关键在于拥有能够倾听问题、规划分析和建模方法解决问题,并管理风险和不确定性的人。许多早期职业的数据科学家在这种情况下表现出色。根据我的经验,他们是沟通能力和处理模糊性的天赋者。我有幸不小心聘用了一些这样的人——嗨,志宇!依赖你的能力来筛选这些人才,并为之竞争,可能会带来风险。
这一切似乎都支持将数据科学职能集中化。这是一种方法,也引出了我们下一个重要的问题。
你的数据科学家是否足够接近业务,以关注正确的问题? 与直接向业务团队汇报的超本地团队相比,中央数据科学职能组可能会较少接触到你希望解决的业务问题。大型、单一的职能团队通过正式的流程,可能很难获得所需的业务输入,主要是因为许多利益相关者不知道自己要提出什么问题。如果你听过数据科学团队产生“没有人要求的科学项目”的恐怖故事,这通常是一个根本原因。而且,再次提醒,不要刻板印象:这很少是因为数据科学团队有过于学术的思维方式,更常见的是因为两个不同职能不知如何用共同语言交流。
这留给我们什么选项?这也是我经验中嵌入式模型有效的一个原因。在这种模型中,你的数据科学团队可以访问你们经常讨论业务问题的所有论坛。他们负责利用这个机会理解业务团队希望解决的问题,并提出可以增加价值的方法。他们向数据科学领导汇报,后者确保他们的工作方法论是正确的,支持他们获得项目成功所需的资源,并指导和辅导他们的成长。
有时数据科学项目失败是因为方法论不佳;它们常常失败是因为预测特征不够有用。知道这两者之间的区别对于非定量职能的人来说可能非常困难。
当我们完成这一步时,我们需要:
-
清晰地定义数据科学家或团队的工作范围
-
定义的参与模式
正如所有实际决策中一样,到处都有权衡,没有万能的解决方案。完全自主的本地团队将最大限度地关注不同的本地结果。集中式职能将最小化重复,但增加了偏离实际、有影响的结果的风险。
退后一步,进行沟通和整体迭代
让我们回顾一下我们迄今为止取得的成就:
-
定义了一个战略假设,即我们将如何通过数据科学和机器学习增加价值的大赌注。
-
确定一个目标投资组合,该投资组合与我们组织的风险承受能力相一致,考虑到你的流程和技术限制,并将我们的团队集中于那些无法通过购买解决的问题上。
-
根据数据访问和它们如何创造价值,筛选我们的使用案例。
-
可能,开发了支持数据科学家的报告结构项目采购方法,并将他们的才能集中于他们独特的优势。
更直白地说,我们已经列出了找到正确使用案例的标准,并筛选了我们的使用案例机会,以找到第一个正确的集合。
接下来的任务是:
-
退后一步,整体查看所有内容。作为一个整体来看,这是否合理?
-
传达这一策略,以及从中衍生出的初步计划。
-
传达潜在利益相关者如何参与你的职能团队。
-
迭代:每当导致策略的假设或情况发生变化时,重新审视你的策略,并承诺定期检查情况的变化。
总结来说,这个过程需要相当大的努力。但是,它带来了巨大的回报。这一策略将明确表达你想要承担的风险、如何管理这些风险,以及它们如何支持你的目标结果(如果成功的话)。目的的明确对齐,以及保持活动与这一目的的一致性,对于一个职能团队来说是非常赋能的。实现这一点,结果将随之而来。
参考文献
-
Brenna 等人,“为现代时代构建研发战略”
-
Corbo 等人,“像技术原生公司一样扩展 AI:CEO 的角色”
-
Kiechel, Walter. 战略之王:新企业世界的秘密知识史(附属链接。)
-
Meakin 等人,“用 AI 获胜是一种心态”
-
Pisano, Gary P. “制定研发战略”
-
Rasiel, Ethan. 麦肯锡方法(附属链接。)
-
Scully 等人,“机器学习:技术债务的高利息信用卡”
在你的 WhatsApp 聊天中构建一个语言模型
通过应用程序了解 GPT 架构的视觉指南
·
关注 发表在 Towards Data Science ·16 分钟阅读·2023 年 11 月 21 日
–
图片由 Volodymyr Hryshchenko 提供,来自 Unsplash
聊天机器人无疑改变了我们与数字平台的互动。尽管基础语言模型在处理复杂任务方面取得了令人印象深刻的进展,但用户体验仍然常常显得不够个人化和疏离。
为了使对话更加贴近实际,我设想了一个可以模拟我随意写作风格的聊天机器人,就像在 WhatsApp 上给朋友发短信一样。
在这篇文章中,我将带你了解我构建一个(小型)语言模型的过程,该模型生成合成对话,使用我的 WhatsApp 聊天记录作为输入数据。过程中,我尝试以可视化且希望易于理解的方式解开 GPT 架构的内部工作机制,并附有实际的 Python 实现。你可以在我的 GitHub 上找到完整项目。
注意: 模型类 本身大块取自 Andrej Karpathy 的视频系列,并根据我的需要进行了调整。我强烈推荐他的教程。
从头开始训练一个语言模型,完全基于你的 WhatsApp 群聊。
github.com](https://github.com/bernhard-pfann/lad-gpt?source=post_page-----31264a9ced90--------------------------------)
目录
-
选定的方法
-
数据来源
-
分词
-
索引
-
模型架构
-
模型训练
-
聊天模式
1. 选定的方法
在将语言模型定制为特定语料库时,可以采取几种方法:
-
模型构建: 这涉及从头开始构建和训练模型,在模型架构和训练数据选择方面提供了最大的灵活性。
-
微调: 这种方法利用现有的预训练模型,调整其权重以更紧密地与手头的特定数据对齐。
-
提示工程: 这也利用了现有的预训练模型,但在这里,独特的语料库直接融入提示中,而不改变模型的权重。
由于我对这个项目的动机主要是自我教育,并且对现代语言模型的架构非常感兴趣,我选择了第一种方法。然而,这种选择带来了明显的限制。鉴于我的数据量和可用计算资源,我并没有期望与任何最先进的预训练模型相当的结果。
尽管如此,我仍希望我的模型能发现一些有趣的语言模式,最终它确实做到了。探索第二种选项(微调)可能会成为未来文章的重点。
2. 数据来源
WhatsApp,作为我的主要沟通渠道,是捕捉我的对话风格的理想来源。导出超过六年的群聊记录,总计超过 150 万字是非常简单的。
数据使用正则表达式模式解析成包含日期、联系人姓名和聊天消息的元组列表。
pattern = r'\[(.*?)\] (.*?): (.*)'
matches = re.findall(pattern, text)
text = [(x1, x2.lower()) for x0, x1, x2 in matches]
[
(2018-03-12 16:03:59, "Alice", "Hi, how are you guys?"),
(2018-03-12 16:05:36, "Tom", "I am good thanks!"),
...
]
现在,每个元素都被单独处理。
-
发送日期: 除了将其转换为日期时间对象外,我没有利用这些信息。然而,可以查看时间差异以区分话题讨论的开始和结束。
-
联系人姓名: 在分词文本时,每个联系人姓名被视为一个唯一的标记。这确保了名和姓的组合仍被视为一个整体。
-
聊天消息: 在每条消息的末尾添加了一个特殊的“”标记。
3. 分词
为了训练语言模型,我们需要将语言分解成片段(所谓的标记),并逐步输入模型。分词可以在多个层次上进行。
-
字符级别: 文本被视为一系列单独的字符(包括空格)。这种细粒度的方法允许从字符序列中形成每一个可能的单词。然而,捕捉单词之间的语义关系会更困难。
-
词级别: 文本被表示为一个单词序列。然而,模型的词汇量受到训练数据中现有单词的限制。
-
子词级别: 文本被拆分成比单词小但比字符大的子词单元。
虽然我最初使用的是字符级别的分词器,但我觉得训练时间被浪费在了学习重复单词的字符序列上,而不是关注句子中单词之间的语义关系。
为了概念上的简洁,我决定切换到词级别的分词器,暂时搁置了用于更复杂分词策略的现有库。
from nltk.tokenize import RegexpTokenizer
def custom_tokenizer(txt: str, spec_tokens: List[str], pattern: str="|\d|\\w+|[^\\s]") -> List[str]:
"""
Tokenize text into words or characters using NLTK's RegexpTokenizer, considerung
given special combinations as single tokens.
:param txt: The corpus as a single string element.
:param spec_tokens: A list of special tokens (e.g. ending, out-of-vocab).
:param pattern: By default the corpus is tokenized on a word level (split by spaces).
Numbers are considered single tokens.
>> note: The pattern for character level tokenization is '|.'
"""
pattern = "|".join(spec_tokens) + pattern
tokenizer = RegexpTokenizer(pattern)
tokens = tokenizer.tokenize(txt)
return tokens
["Alice:", "Hi", "how", "are", "you", "guys", "?", "<END>", "Tom:", ... ]
结果显示我的训练数据有大约 70,000 个独特的单词。然而,由于许多单词仅出现一次或两次,我决定用“”特殊标记替代这些稀有单词。这减少了词汇量至大约 25,000 个单词,从而得到一个较小的模型,后续训练也会更简单。
from collections import Counter
def get_infrequent_tokens(tokens: Union[List[str], str], min_count: int) -> List[str]:
"""
Identify tokens that appear less than a minimum count.
:param tokens: When it is the raw text in a string, frequencies are counted on character level.
When it is the tokenized corpus as list, frequencies are counted on token level.
:min_count: Threshold of occurence to flag a token.
:return: List of tokens that appear infrequently.
"""
counts = Counter(tokens)
infreq_tokens = set([k for k,v in counts.items() if v<=min_count])
return infreq_tokens
def mask_tokens(tokens: List[str], mask: Set[str]) -> List[str]:
"""
Iterate through all tokens. Any token that is part of the set, is replaced by the unknown token.
:param tokens: The tokenized corpus.
:param mask: Set of tokens that shall be masked in the corpus.
:return: List of tokenized corpus after the masking operation.
"""
return [t.replace(t, unknown_token) if t in mask else t for t in tokens]
infreq_tokens = get_infrequent_tokens(tokens, min_count=2)
tokens = mask_tokens(tokens, infreq_tokens)
["Alice:", "Hi", "how", "are", "you", "<UNK>", "?", "<END>", "Tom:", ... ]
4. 索引
在分词之后,下一步是将单词和特殊标记转换为数值表示。使用固定的词汇表,每个单词按其位置进行了索引。编码后的单词随后被准备为 PyTorch 张量。
import torch
def encode(s: list, vocab: list) -> torch.tensor:
"""
Encode a list of tokens into a tensor of integers, given a fixed vocabulary.
When a token is not found in the vocabulary, the special unknown token is assigned.
When the training set did not use that special token, a random token is assigned.
"""
rand_token = random.randint(0, len(vocab))
map = {s:i for i,s in enumerate(vocab)}
enc = [map.get(c, map.get(unknown_token, rand_token)) for c in s]
enc = torch.tensor(enc, dtype=torch.long)
return enc
torch.tensor([8127, 115, 2363, 3, ..., 14028])
由于我们需要评估模型在一些未见数据上的质量,我们将张量分成两部分。这样,我们就得到了训练集和验证集,可以准备好喂给语言模型。
作者提供的图片
5. 模型架构
我决定应用 GPT 架构,这一架构在具有影响力的论文“Attention is All you Need”中得到了推广。由于我试图构建的是语言生成器而不是问答机器人,因此仅使用解码器(右侧)架构足以满足这一目的。
“Attention is All you Need” 由 A. Vaswani 等人撰写(取自 arXiv: 1706.03762)
在接下来的部分中,我将分解 GPT 架构的每个组件,解释其作用以及基础的矩阵运算。从准备好的训练测试开始,我将追踪一个示例上下文的 3 个词,通过模型,直到它预测下一个令牌。
5.1. 模型目标
在深入技术细节之前,了解我们模型的主要目标至关重要。在仅解码器的设置中,我们的目标是解码语言的结构,以准确预测序列中的下一个令牌,前提是给定前面的令牌上下文。
图片由作者提供
当我们将索引的令牌序列输入模型时,它会经历一系列与各种权重矩阵的矩阵乘法运算。输出是一个向量,表示每个令牌在序列中作为下一个令牌的概率,这个概率基于输入上下文。
模型评估:
我们的模型性能通过训练数据来评估,其中实际的下一个令牌是已知的。目标是最大化正确预测这个下一个令牌的概率。
然而,在机器学习中,我们常常关注“损失”这一概念,它量化了错误或不正确预测的可能性。为了计算这一点,我们将模型的输出概率与实际的下一个令牌进行比较(使用cross-entropy)。
优化:
通过了解当前的损失,我们旨在通过反向传播来最小化它。这个过程涉及迭代地将令牌序列输入模型,并调整权重矩阵以提升性能。
在每张图中,我将用黄色标出在该过程中将被优化的权重矩阵。
5.2. 输出嵌入
到目前为止,我们序列中的每个令牌都由一个整数索引表示。然而,这种简单的形式无法反映单词之间的关系或相似性。为了解决这个问题,我们将一维索引提升到更高维度的空间中,通过嵌入实现。
-
词嵌入: 单词的本质由一个 n 维的浮点向量来捕捉。
-
位置嵌入: 这些嵌入强调了单词在句子中的位置的重要性,也表示为 n 维的浮点向量。
对于每个令牌,我们查找其词嵌入和位置嵌入,然后逐元素相加。这就得出了每个令牌在上下文中的输出嵌入。
在下面的示例中,上下文包含 3 个令牌。在嵌入过程结束时,每个令牌由一个 n 维向量表示(其中 n 是嵌入大小,一个可调的超参数)。
图片由作者提供
PyTorch 提供了专门的类来处理这些嵌入。在我们的模型类中,我们定义了词嵌入和位置嵌入,如下所示(传递矩阵维度作为参数):
self.word_embedding = nn.Embedding(vocab_size, embed_size)
self.pos_embedding = nn.Embedding(block_size, embed_size)
5.3. 自注意力头
虽然词嵌入提供了词语相似性的整体感觉,但一个词的真实含义往往取决于其周围的上下文。例如,“bat”可能指的是动物或运动器材,这取决于句子。这就是自注意力机制(GPT 架构的关键组成部分)发挥作用的地方。
自注意力机制关注三个主要概念:查询(Q)、键(K)和值(V)。
-
查询(Q): 查询本质上是当前标记的表示,注意力需要计算它。就像在问:“作为当前标记,我应该关注上下文中的什么?”
-
键(K): 键是输入序列中每个标记的表示。它们与查询配对,以确定注意力分数。这种比较衡量了查询标记应将多少关注放在上下文中的其他标记上。高分意味着应该更多关注。
-
值(V): 值也是输入序列中每个标记的表示。然而,它们的作用不同,因为它们对注意力分数施加最终加权。
作者提供的图像
示例:
在我们的示例中,上下文中的每个标记已经是嵌入形式,作为 n 维向量(e1, e2, e3)。自注意力头将它们作为输入,以逐一输出每个标记的上下文化版本。
-
在评估“name”这个标记时,通过将其嵌入向量v2与可训练矩阵M_Q相乘,得到一个查询向量q。
-
同时,为上下文中的每个标记计算键向量**(k1, k2, k3),通过将每个嵌入向量(e1, e2, e3)与可训练矩阵M_K**相乘。
-
值向量**(v1, v2, v3)** 以相同的方式获得,只是乘以不同的可训练矩阵M_V。
-
注意力分数w通过查询向量与每个键向量之间的点积来计算。
-
最后,我们将所有值向量堆叠成一个矩阵,并将其与注意力分数相乘,以获得标记“name”的上下文化向量。
class Head(nn.Module):
"""
This module performs self-attention operations on the input tensor, producing
an output tensor with the same time-steps but different channels.
:param head_size: The size of the head in the multi-head attention mechanism.
"""
def __init__(self, head_size):
super().__init__()
self.key = nn.Linear(embed_size, head_size, bias=False)
self.query = nn.Linear(embed_size, head_size, bias=False)
self.value = nn.Linear(embed_size, head_size, bias=False)
def forward(self, x):
"""
# input of size (batch, time-step, channels)
# output of size (batch, time-step, head size)
"""
B,T,C = x.shape
k = self.key(x)
q = self.query(x)
# compute attention scores
wei = q @ k.transpose(-2,-1)
wei /= math.sqrt(k.shape[-1])
# avoid look-ahead
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)
# weighted aggregation of the values
v = self.value(x)
out = wei @ v
return out
5.4. 掩蔽多头注意力
语言是复杂的,捕捉其所有的细微差别并不简单。一组注意力计算通常不足以捕捉词语如何相互作用的细微之处。这就是 GPT 模型中的多头注意力的理念派上用场的地方。
你可以把多头注意力想象成多个“眼睛”以不同的方式观察数据,每个“眼睛”注意到独特的细节。这些独立的观察结果然后被整合成一个大图景。为了使这个大图景易于管理并与我们模型的其余部分兼容,我们使用线性层(可训练权重)将其压缩回原始的嵌入大小。
最后,为了确保我们的模型不仅仅记住训练数据,还能在新文本上进行良好的预测,我们使用了一个 dropout 层。这个层在训练过程中随机关闭数据的部分内容,帮助模型变得更加适应。
作者提供的图片
class MultiHeadAttention(nn.Module):
"""
This class contains multiple `Head` objects, which perform self-attention
operations in parallel.
"""
def __init__(self):
super().__init__()
head_size = embed_size // n_heads
heads_list = [Head(head_size) for _ in range(n_heads)]
self.heads = nn.ModuleList(heads_list)
self.linear = nn.Linear(n_heads * head_size, embed_size)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
heads_list = [h(x) for h in self.heads]
out = torch.cat(heads_list, dim=-1)
out = self.linear(out)
out = self.dropout(out)
return out
5.5. 前馈
多头注意力层最初捕捉了序列中的上下文关系。通过两个连续的线性层为网络添加了更多深度,这两个层共同构成了前馈神经网络。
作者提供的图片
在初始线性层中,我们增加了维度(在我们的例子中是增加了 4 倍),这有效地拓宽了网络学习和表示更复杂特征的能力。对结果矩阵的每个元素应用ReLU函数,使得非线性模式能够被识别。
随后,第二个线性层作为一个压缩器,将扩展的维度减少回原始形状(块大小 x 嵌入大小)。Dropout 层结束了这个过程,随机地停用矩阵的部分元素,以实现模型的泛化。
class FeedFoward(nn.Module):
"""
This module passes the input tensor through a series of linear transformations
and non-linear activations.
"""
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(embed_size, 4 * embed_size),
nn.ReLU(),
nn.Linear(4 * embed_size, embed_size),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
5.6. Add & Norm
现在,我们通过引入两个关键元素,将多头注意力和前馈组件连接在一起:
-
残差连接(Add): 这些连接执行层输出与其未更改输入的逐元素加法。在训练过程中,模型根据层变换的有用性调整对层变换的重视程度。如果某个变换被认为是不必要的,其权重和相应的层输出将趋向于零。在这种情况下,至少未更改的输入会通过残差连接传递。这种技术有助于缓解梯度消失问题。
-
层归一化(Norm): 这种方法通过减去嵌入向量的均值并除以其标准差来归一化上下文中的每个嵌入向量。这个过程还确保了在反向传播过程中梯度不会爆炸或消失。
作者提供的图片
多头注意力和前馈层的链条,通过“Add & Norm”连接,合并成一个块。这种模块化设计使我们能够形成一系列块。这些块的数量是一个超参数,它决定了模型架构的深度。
class Block(nn.Module):
"""
This module contains a single transformer block, which consists of multi-head
self-attention followed by feed-forward neural networks.
"""
def __init__(self):
super().__init__()
self.sa = MultiHeadAttention()
self.ffwd = FeedFoward()
self.ln1 = nn.LayerNorm(embed_size)
self.ln2 = nn.LayerNorm(embed_size)
def forward(self, x):
x = x + self.sa(self.ln1(x))
x = x + self.ffwd(self.ln2(x))
return x
5.7. Softmax
在遍历多个块组件后,我们获得了一个维度为(块大小 x 嵌入大小)的矩阵。为了将这个矩阵重塑为所需的维度(块大小 x 词汇表大小),我们将其通过一个最终的线性层。这个形状表示了上下文中每个位置词汇表中每个词的一个条目。
最后,我们对这些值应用 soft-max 变换,将其转换为概率。我们成功地获得了上下文中每个位置的下一个标记的概率分布。
6. 模型训练
为了训练语言模型,我从训练数据中的随机位置选择了令牌序列。鉴于 WhatsApp 对话的快节奏,我确定 32 个词的上下文长度足够。因此,我选择了随机的 32 词块作为上下文输入,并使用相应的向量(向后移动一个词)作为比较的目标。
训练过程循环执行以下步骤:
-
对多个批次的上下文进行采样。
-
将这些样本输入模型,以计算当前损失。
-
根据当前的损失和模型权重应用反向传播。
-
每 500 次迭代更全面地评估损失。
一旦所有其他模型超参数(如嵌入大小、自注意力头数量等)确定后,我最终选择了一个具有 250 万参数的模型。考虑到我对输入数据大小和计算资源的限制,我发现这是对我而言的最佳设置。
训练过程大约花费了 12 小时,完成了 10,000 次迭代。可以看到,训练本可以更早停止,因为验证集和训练集上的损失差距在扩大。
作者提供的图片
import json
import torch
from config import eval_interval, learn_rate, max_iters
from src.model import GPTLanguageModel
from src.utils import current_time, estimate_loss, get_batch
def model_training(update: bool) -> None:
"""
Trains or updates a GPTLanguageModel using pre-loaded data.
This function either initializes a new model or loads an existing model based
on the `update` parameter. It then trains the model using the AdamW optimizer
on the training and validation data sets. Finally the trained model is saved.
:param update: Boolean flag to indicate whether to update an existing model.
"""
# LOAD DATA -----------------------------------------------------------------
train_data = torch.load("assets/output/train.pt")
valid_data = torch.load("assets/output/valid.pt")
with open("assets/output/vocab.txt", "r", encoding="utf-8") as f:
vocab = json.loads(f.read())
# INITIALIZE / LOAD MODEL ---------------------------------------------------
if update:
try:
model = torch.load("assets/models/model.pt")
print("Loaded existing model to continue training.")
except FileNotFoundError:
print("No existing model found. Initializing a new model.")
model = GPTLanguageModel(vocab_size=len(vocab))
else:
print("Initializing a new model.")
model = GPTLanguageModel(vocab_size=len(vocab))
# initialize optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learn_rate)
# number of model parameters
n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters to be optimized: {n_params}\n", )
# MODEL TRAINING ------------------------------------------------------------
for i in range(max_iters):
# evaluate the loss on train and valid sets every 'eval_interval' steps
if i % eval_interval == 0 or i == max_iters - 1:
train_loss = estimate_loss(model, train_data)
valid_loss = estimate_loss(model, valid_data)
time = current_time()
print(f"{time} | step {i}: train loss {train_loss:.4f}, valid loss {valid_loss:.4f}")
# sample batch of data
x_batch, y_batch = get_batch(train_data)
# evaluate the loss
logits, loss = model(x_batch, y_batch)
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
torch.save(model, "assets/models/model.pt")
print("Model saved")
7. 聊天模式
为了与训练好的模型进行交互,我创建了一个函数,允许通过下拉菜单选择联系人姓名,并输入消息供模型响应。参数“n_chats”决定模型一次生成的响应数量。当模型预测标记为下一个标记时,模型结束生成消息。
import json
import random
import torch
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from config import end_token, n_chats
from src.utils import custom_tokenizer, decode, encode, print_delayed
def conversation() -> None:
"""
Emulates chat conversations by sampling from a pre-trained GPTLanguageModel.
This function loads a trained GPTLanguageModel along with vocabulary and
the list of special tokens. It then enters into a loop where the user specifies
a contact. Given this input, the model generates a sample response. The conversation
continues until the user inputs the end token.
"""
with open("assets/output/vocab.txt", "r", encoding="utf-8") as f:
vocab = json.loads(f.read())
with open("assets/output/contacts.txt", "r", encoding="utf-8") as f:
contacts = json.loads(f.read())
spec_tokens = contacts + [end_token]
model = torch.load("assets/models/model.pt")
completer = WordCompleter(spec_tokens, ignore_case=True)
input = prompt("message >> ", completer=completer, default="")
output = torch.tensor([], dtype=torch.long)
print()
while input != end_token:
for _ in range(n_chats):
add_tokens = custom_tokenizer(input, spec_tokens)
add_context = encode(add_tokens, vocab)
context = torch.cat((output, add_context)).unsqueeze(1).T
n0 = len(output)
output = model.generate(context, vocab)
n1 = len(output)
print_delayed(decode(output[n0-n1:], vocab))
input = random.choice(contacts)
input = prompt("\nresponse >> ", completer=completer, default="")
print()
结论:
由于个人聊天的隐私,我无法在这里展示示例提示和对话。
尽管如此,你可以期待这样规模的模型能够成功学习句子的总体结构,产生有意义的输出,尤其是在词序方面。在我的案例中,它也掌握了训练数据中某些重要话题的上下文。例如,由于我的个人聊天经常围绕网球展开,网球运动员的名字和与网球相关的词通常会一起输出。
然而,在评估生成句子的连贯性时,我承认结果并没有达到我已经很低的期望。当然,我也可以责怪我的朋友们聊了太多无聊的话,限制了模型学习有用内容的能力…
为了在结尾展示至少一些示例输出,你可以查看虚拟模型在 200 条训练虚拟消息上的表现 😉
作者提供的图片