TowardsDataScience 2023 博客中文翻译(五十九)

原文:TowardsDataScience

协议:CC BY-NC-SA 4.0

使用 Athena 和 MySQL 构建批量数据管道

原文:towardsdatascience.com/building-a-batch-data-pipeline-with-athena-and-mysql-7e60575ff39c

初学者的端到端教程

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 💡Mike Shakhomirov

·发布于 Towards Data Science ·16 min 阅读·2023 年 10 月 20 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片由 Redd F 提供,来自 Unsplash

在这个故事中,我将讲述一种非常流行的数据转换任务执行方式——批量数据处理。当我们需要以块状方式处理数据时,这种数据管道设计模式变得极其有用,非常适合需要调度的 ETL 作业。我将通过使用 MySQL 和 Athena 构建数据转换管道来展示如何实现这一目标。我们将使用基础设施即代码在云中部署它。

想象一下,你刚刚作为数据工程师加入了一家公司。他们的数据堆栈现代、事件驱动、成本效益高、灵活,并且可以轻松扩展以满足不断增长的数据资源。你数据平台中的外部数据源和数据管道由数据工程团队管理,使用具有 CI/CD GitHub 集成的灵活环境设置。

作为数据工程师,你需要创建一个业务智能仪表板,展示公司收入来源的地理分布,如下所示。原始支付数据存储在服务器数据库(MySQL)中。你想构建一个批量管道,从该数据库中每日提取数据,然后使用 AWS S3 存储数据文件,并使用 Athena 进行处理。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

收入仪表板。图像由作者提供。

批量数据管道

数据管道可以被视为一系列数据处理步骤。由于这些阶段之间的逻辑数据流连接,每个阶段生成的输出作为下一个阶段的输入

只要在点 A 和点 B 之间进行数据处理,就存在数据管道。

数据管道可能因其概念和逻辑性质而有所不同。我之前在这里写过 [1]:

数据管道设计模式

选择合适的架构及其示例

数据管道设计模式

我们希望创建一个数据管道,在以下 步骤 中转换数据:

1. 使用 Lambda 函数将数据从 MySQL 数据库表 myschema.usersmyschema.transactions 提取到 S3 数据湖桶中。

2. 添加一个具有 Athena 资源的状态机节点以启动执行 (arn:aws:states:::athena:startQueryExecution.sync) 并创建一个名为 mydatabase 的数据库

3. 创建另一个数据管道节点以显示 Athena 数据库中的现有表。使用该节点的输出执行所需的数据转换。

如果表不存在,我们希望我们的管道在 Athena 中根据来自数据湖 S3 桶的数据创建它们。我们希望创建两个 外部表,数据来自 MySQL:

  • mydatabase.users (LOCATION ‘s3://<YOUR_DATALAKE_BUCKET>/data/myschema/users/’)

  • mydatabase.transactions (LOCATION ‘s3://<YOUR_DATALAKE_BUCKET>/data/myschema/transactions/’)

然后我们希望创建一个 优化的 ICEBERG 表:

  • mydatabase.user_transactions (‘table_type’=’ICEBERG’, ‘format’=’parquet’) 使用以下 SQL:
SELECT 
      date(dt) dt
    , user_id
    , sum(total_cost) total_cost_usd
    , registration_date
  FROM mydatabase.transactions 
  LEFT JOIN mydatabase.users
  ON users.id = transactions.user_id
  GROUP BY
    dt
    , user_id
    , registration_date
;
  • 我们还将使用 MERGE 来更新此表。

MERGE 是一种非常有用的 SQL 技巧,用于表中的增量更新。查看我之前的故事 [3] 以获取更高级的示例:

高级 SQL 技巧

从 1 到 10,你的数据仓库技能有多好?

高级 SQL 技巧

Athena 可以通过运行有吸引力的即席 SQL 查询来分析存储在 Amazon S3 中的结构化、非结构化和半结构化数据,无需管理基础设施。

我们不需要加载数据,这使得它成为我们任务的完美选择。

它可以轻松地与 Business Intelligence (BI) 解决方案如 QuickSight 集成以生成报告。

ICEBERG 是一种非常有用且高效的表格格式,多个独立程序可以同时且一致地处理相同的数据集 [2]。我之前在这里写过:

介绍 Apache Iceberg 表

选择 Apache Iceberg 作为数据湖的几个有力理由

介绍 Apache Iceberg 表

MySQL 数据连接器

让我们创建一个 AWS Lambda 函数,它能够在 MySQL 数据库中执行 SQL 查询。

代码非常简单且通用。它可以在任何无服务器应用程序中与任何云服务提供商一起使用。

我们将使用它将收入数据提取到数据湖中。建议的 Lambda 文件夹结构如下所示:

.
└── stack
    ├── mysql_connector
    │   ├── config       # config folder with environment related settings
    │   ├── populate_database.sql  # sql script to create source tables
    │   ├── export.sql   # sql script to export data to s3 datalake
    │   └── app.py       # main application file
    ├── package          # required libraries
    │   ├── PyMySQL-1.0.2.dist-info
    │   └── pymysql
    ├── requirements.txt # required Python modules
    └── stack.zip        # Lambda package

我们将通过 AWS Step Functions 将这个小服务集成到管道中,以便于 编排和可视化

为了创建一个能够从 MySQL 数据库中提取数据的 Lambda 函数,我们需要先为我们的 Lambda 创建一个文件夹。首先创建一个名为 stack的新文件夹,然后在其中创建一个名为mysql_connector` 的文件夹:

mkdir stack
cd stack
mkdir mysql_connector

然后我们可以使用下面的代码(将数据库连接设置替换为你的设置)来创建 app.py

 import os
import sys
import yaml
import logging
import pymysql

from datetime import datetime
import pytz

ENV = os.environ['ENV']
TESTING = os.environ['TESTING']
LAMBDA_PATH = os.environ['LAMBDA_PATH']
print('ENV: {}, Running locally: {}'.format(ENV, TESTING))

def get_work_dir(testing):
    if (testing == 'true'):
        return LAMBDA_PATH
    else:
        return '/var/task/' + LAMBDA_PATH

def get_settings(env, path):
    if (env == 'staging'):
        with open(path + "config/staging.yaml", "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    elif (env == 'live'):
        with open(path + "config/production.yaml", "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    elif (env == 'test'):
        with open(path + "config/test.yaml", "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    else:
        print('No config found')
    return config

work_dir = get_work_dir(TESTING)
print('LAMBDA_PATH: {}'.format(work_dir))
config=get_settings(ENV, work_dir)
print(config)
DATA_S3 = config.get('S3dataLocation') # i.e. datalake.staging.something. Replace it with your unique bucket name.

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# rds settings
rds_host  = config.get('Mysql')['rds_host'] # i.e. "mymysqldb.12345.eu-west-1.rds.amazonaws.com"
user_name = "root"
password = "AmazingPassword"
db_name = "mysql"

# create the database connection outside of the handler to allow connections to be
# re-used by subsequent function invocations.
try:
    conn = pymysql.connect(host=rds_host, user=user_name, passwd=password, db=db_name, connect_timeout=5)

except pymysql.MySQLError as e:
    logger.error("ERROR: Unexpected error: Could not connect to MySQL instance.")
    logger.error(e)
    sys.exit()

logger.info("SUCCESS: Connection to RDS MySQL instance succeeded")

def lambda_handler(event, context):
    processed = 0
    print("")
    try:
        _populate_db()
        _export_to_s3()
    except Exception as e:
        print(e)
    message = 'Successfully populated the database and created an export job.'
    return {
        'statusCode': 200,
        'body': { 'lambdaResult': message }
    }

# Helpers:

def _now():
    return datetime.utcnow().replace(tzinfo=pytz.utc).strftime('%Y-%m-%dT%H:%M:%S.%f')

def _populate_db():
    try:
        # Generate data and populate database:
        fd = open(work_dir + '/populate_database.sql', 'r')
        sqlFile = fd.read()
        fd.close()
        sqlCommands = sqlFile.split(';')
        # Execute every command from the input file
        for command in sqlCommands:
            try:
                with conn.cursor() as cur:
                    cur.execute(command)
                    print('---')
                    print(command)
            except Exception as e:
                print(e)

    except Exception as e:
        print(e)

def _export_to_s3():
    try:
        # Generate data and populate database:
        fd = open(work_dir + '/export.sql', 'r')
        sqlFile = fd.read()
        fd.close()
        sqlCommands = sqlFile.split(';')
        # Execute every command from the input file
        for command in sqlCommands:
            try:
                with conn.cursor() as cur:
                    cur.execute(command.replace("{{DATA_S3}}", DATA_S3))
                    print('---')
                    print(command)
            except Exception as e:
                print(e)

    except Exception as e:
        print(e)

要使用 AWS CLI 部署我们的微服务,请在命令行中运行以下命令(假设你在 ./stack 文件夹中):

# Package Lambda code:
base=${PWD##*/}
zp=$base".zip" # This will return stack.zip if you are in stack folder.
echo $zp

rm -f $zp # remove old package if exists

pip install --target ./package pymysql 

cd package
zip -r ../${base}.zip .

cd $OLDPWD
zip -r $zp ./mysql_connector

确保在运行下一部分之前 AWS Lambda 角色已经存在 — role arn:aws:iam::<your-aws-account-id>:role/my-lambda-role

# Deploy packaged Lambda using AWS CLI:
aws \
lambda create-function \
--function-name mysql-lambda \
--zip-file fileb://stack.zip \
--handler <path-to-your-lambda-handler>/app.lambda_handler \
--runtime python3.12 \
--role arn:aws:iam::<your-aws-account-id>:role/my-lambda-role

# # If already deployed then use this to update:
# aws --profile mds lambda update-function-code \
# --function-name mysql-lambda \
# --zip-file fileb://stack.zip;

我们的 MySQL 实例必须具备 S3 集成,以便 将数据导出到 S3 桶。这可以通过运行以下 SQL 查询实现:

-- Example query
-- Replace table names and S3 bucket location
SELECT * FROM myschema.transactions INTO OUTFILE S3 's3://<YOUR_S3_BUCKET>/data/myschema/transactions/transactions.scv' FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n' OVERWRITE ON;

如何创建 MySQL 实例

我们可以使用 CloudFormation 模板和基础设施即代码来创建 MySQL 数据库。考虑这个 AWS 命令:

aws \
cloudformation deploy \
--template-file cfn_mysql.yaml \
--stack-name MySQLDB \
--capabilities CAPABILITY_IAM

它将使用 cfn_mysql.yaml 模板文件来创建名为 MySQLDB 的 CloudFormation 堆栈。我之前在这里写过有关它的内容 [4]:

## 使用 AWS CloudFormation 创建 MySQL 和 Postgres 实例

数据库从业人员的基础设施即代码

towardsdatascience.com

我们的 cfn_mysql.yaml 应该如下所示:

AWSTemplateFormatVersion: 2010-09-09
Description: >-
  This
  template creates an Amazon Relational Database Service database instance. You
  will be billed for the AWS resources used if you create a stack from this
  template.
Parameters:
  DBUser:
    Default: root
    NoEcho: 'true'
    Description: The database admin account username
    Type: String
    MinLength: '1'
    MaxLength: '16'
    AllowedPattern: '[a-zA-Z][a-zA-Z0-9]*'
    ConstraintDescription: must begin with a letter and contain only alphanumeric characters.
  DBPassword:
    Default: AmazingPassword
    NoEcho: 'true'
    Description: The database admin account password
    Type: String
    MinLength: '8'
    MaxLength: '41'
    AllowedPattern: '[a-zA-Z0-9]*'
    ConstraintDescription: must contain only alphanumeric characters.
Resources:
### Role to output into s3
  MySQLRDSExecutionRole:
    Type: "AWS::IAM::Role"
    Properties:
      AssumeRolePolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: "Allow"
            Principal:
              Service:
                - !Sub rds.amazonaws.com
            Action: "sts:AssumeRole"
      Path: "/"
      Policies:
        - PolicyName: MySQLRDSExecutionPolicy
          PolicyDocument:
            Version: "2012-10-17"
            Statement:
              - Effect: Allow
                Action:
                  - "s3:*"
                Resource: "*"
###

  RDSCluster: 
    Properties: 
      DBClusterParameterGroupName: 
        Ref: RDSDBClusterParameterGroup
      Engine: aurora-mysql
      MasterUserPassword: 
        Ref: DBPassword
      MasterUsername: 
        Ref: DBUser

### Add a role to export to s3
      AssociatedRoles:
        - RoleArn: !GetAtt [ MySQLRDSExecutionRole, Arn ]
###
    Type: "AWS::RDS::DBCluster"
  RDSDBClusterParameterGroup: 
    Properties: 
      Description: "CloudFormation Sample Aurora Cluster Parameter Group"
      Family: aurora-mysql5.7
      Parameters: 
        time_zone: US/Eastern
        ### Add a role to export to s3
        aws_default_s3_role: !GetAtt [ MySQLRDSExecutionRole, Arn ]
        ###
    Type: "AWS::RDS::DBClusterParameterGroup"
  RDSDBInstance1:
    Type: 'AWS::RDS::DBInstance'
    Properties:
      DBClusterIdentifier: 
        Ref: RDSCluster
      # AllocatedStorage: '20'
      DBInstanceClass: db.t2.small
      # Engine: aurora
      Engine: aurora-mysql
      PubliclyAccessible: "true"
      DBInstanceIdentifier: MyMySQLDB
  RDSDBParameterGroup:
    Type: 'AWS::RDS::DBParameterGroup'
    Properties:
      Description: CloudFormation Sample Aurora Parameter Group
      # Family: aurora5.6
      Family: aurora-mysql5.7
      Parameters:
        sql_mode: IGNORE_SPACE
        max_allowed_packet: 1024
        innodb_buffer_pool_size: '{DBInstanceClassMemory*3/4}'
# Aurora instances need to be associated with a AWS::RDS::DBCluster via DBClusterIdentifier without the cluster you get these generic errors 

如果一切顺利,我们将看到 Amazon 账户中出现一个新的堆栈:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

带有 MySQL 实例的 CloudFormation 堆栈。图片由作者提供。

现在我们可以在我们的数据管道中使用这个 MySQL 实例。我们可以在任何 SQL 工具中尝试我们的 SQL 查询,例如 SQL Workbench,以填充表数据。这些表将用于稍后使用 Athena 创建外部表,可以通过 SQL 创建:

CREATE TABLE IF NOT EXISTS
  myschema.users AS
SELECT
  1 AS id,
  CURRENT_DATE() AS registration_date
UNION ALL
SELECT
  2 AS id,
  DATE_SUB(CURRENT_DATE(), INTERVAL 1 day) AS registration_date;

CREATE TABLE IF NOT EXISTS
  myschema.transactions AS
SELECT
  1 AS transaction_id,
  1 AS user_id,
  10.99 AS total_cost,
  CURRENT_DATE() AS dt
UNION ALL
SELECT
  2 AS transaction_id,
  2 AS user_id,
  4.99 AS total_cost,
  CURRENT_DATE() AS dt
UNION ALL
SELECT
  3 AS transaction_id,
  2 AS user_id,
  4.99 AS total_cost,
  DATE_SUB(CURRENT_DATE(), INTERVAL 3 day) AS dt
UNION ALL
SELECT
  4 AS transaction_id,
  1 AS user_id,
  4.99 AS total_cost,
  DATE_SUB(CURRENT_DATE(), INTERVAL 3 day) AS dt
UNION ALL
SELECT
  5 AS transaction_id,
  1 AS user_id,
  5.99 AS total_cost,
  DATE_SUB(CURRENT_DATE(), INTERVAL 2 day) AS dt
UNION ALL
SELECT
  6 AS transaction_id,
  1 AS user_id,
  15.99 AS total_cost,
  DATE_SUB(CURRENT_DATE(), INTERVAL 1 day) AS dt
UNION ALL
SELECT
  7 AS transaction_id,
  1 AS user_id,
  55.99 AS total_cost,
  DATE_SUB(CURRENT_DATE(), INTERVAL 4 day) AS dt
;

使用 Athena 处理数据

现在我们希望添加一个数据管道工作流,该工作流触发我们的 Lambda 函数以从 MySQL 提取数据,将其保存到数据湖中,然后在 Athena 中开始数据转换。

我们希望使用 MySQL 中的数据创建两个外部 Athena 表:

  • myschema.users

  • myschema.transactions

然后我们希望创建一个优化的 ICEBERG 表 myschema.user_transactions,将其连接到我们的 BI 解决方案。

我们希望使用 MERGE 语句将新数据插入到该表中。

CREATE EXTERNAL TABLE mydatabase.users (
    id                bigint
  , registration_date string
) 
ROW FORMAT DELIMITED
FIELDS TERMINATED BY ',' 
STORED AS INPUTFORMAT   'org.apache.hadoop.mapred.TextInputFormat'
OUTPUTFORMAT   'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' 
LOCATION  's3://<YOUR_S3_BUCKET>/data/myschema/users/' TBLPROPERTIES (  'skip.header.line.count'='0')
;
select * from mydatabase.users;

CREATE EXTERNAL TABLE mydatabase.transactions (
    transaction_id    bigint
  , user_id           bigint
  , total_cost        double
  , dt                string
) 
ROW FORMAT DELIMITED
FIELDS TERMINATED BY ',' 
STORED AS INPUTFORMAT   'org.apache.hadoop.mapred.TextInputFormat'
OUTPUTFORMAT   'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' 
LOCATION  's3://<YOUR_S3_BUCKET>/data/myschema/transactions/' TBLPROPERTIES (  'skip.header.line.count'='0')
;
select * from mydatabase.transactions;

CREATE TABLE IF NOT EXISTS mydatabase.user_transactions (
  dt date,
  user_id int,
  total_cost_usd float,
  registration_date string
) 
PARTITIONED BY (dt)
LOCATION 's3://<YOUR_S3_BUCKET>/data/myschema/optimized-data-iceberg-parquet/' 
TBLPROPERTIES (
  'table_type'='ICEBERG',
  'format'='parquet',
  'write_target_data_file_size_bytes'='536870912',
  'optimize_rewrite_delete_file_threshold'='10'
)
;

MERGE INTO mydatabase.user_transactions  as ut
USING (
  SELECT 
      date(dt) dt
    , user_id
    , sum(total_cost) total_cost_usd
    , registration_date
  FROM mydatabase.transactions 
  LEFT JOIN mydatabase.users
  ON users.id = transactions.user_id
  GROUP BY
    dt
    , user_id
    , registration_date
) as ut2
ON (ut.dt = ut2.dt and ut.user_id = ut2.user_id)
WHEN MATCHED
    THEN UPDATE
        SET total_cost_usd = ut2.total_cost_usd, registration_date = ut2.registration_date
WHEN NOT MATCHED 
THEN INSERT (
 dt
,user_id
,total_cost_usd
,registration_date
)
  VALUES (
 ut2.dt
,ut2.user_id
,ut2.total_cost_usd
,ut2.registration_date
)
;

当新表准备好后,我们可以通过运行 SELECT * 来检查它:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

mydatabase.user_transactions。图片由作者提供。

使用 Step Functions(状态机)编排数据管道

在之前的步骤中,我们学习了如何分别部署数据管道的每一步并进行测试。在这一段中,我们将了解如何使用基础设施代码和管道编排工具如 AWS Step Functions(状态机)创建一个完整的数据管道。当我们完成时,管道图将如下所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 Step Functions 进行数据管道编排。图像由作者提供。

数据管道编排是一种很好的数据工程技术,它为我们的数据管道增加了互动性。这个想法在我之前的一篇故事中已经解释过[5]:

## 数据管道编排

数据管道管理得当可以简化部署并提高数据的可用性和可访问性……

[towardsdatascience.com

要部署完整的编排器解决方案,包括所有必要的资源,我们可以使用 CloudFormation(基础设施即代码)。考虑下面这个可以在/stack文件夹中从命令行运行的脚本。确保<YOUR_S3_BUCKET>存在,并将其替换为您的实际 S3 桶:

#!/usr/bin/env bash
# chmod +x ./deploy-staging.sh
# Run ./deploy-staging.sh
PROFILE=<YOUR_AWS_PROFILE>
STACK_NAME=BatchETLpipeline
LAMBDA_BUCKET=<YOUR_S3_BUCKET> # Replace with unique bucket name in your account
APP_FOLDER=mysql_connector

date

TIME=`date +"%Y%m%d%H%M%S"`

base=${PWD##*/}
zp=$base".zip"
echo $zp

rm -f $zp

pip install --target ./package -r requirements.txt
# boto3 is not required unless we want a specific version for Lambda
# requirements.txt:
# pymysql==1.0.3
# requests==2.28.1
# pytz==2023.3
# pyyaml==6.0

cd package
zip -r ../${base}.zip .

cd $OLDPWD

zip -r $zp "./${APP_FOLDER}" -x __pycache__ 

# Check if Lambda bucket exists:
LAMBDA_BUCKET_EXISTS=$(aws --profile ${PROFILE} s3 ls ${LAMBDA_BUCKET} --output text)
#  If NOT:
if [[ $? -eq 254 ]]; then
    # create a bucket to keep Lambdas packaged files:
    echo  "Creating Lambda code bucket ${LAMBDA_BUCKET} "
    CREATE_BUCKET=$(aws --profile ${PROFILE} s3 mb s3://${LAMBDA_BUCKET} --output text)
    echo ${CREATE_BUCKET}
fi

# Upload the package to S3:
aws --profile $PROFILE s3 cp ./${base}.zip s3://${LAMBDA_BUCKET}/${APP_FOLDER}/${base}${TIME}.zip

aws --profile $PROFILE \
cloudformation deploy \
--template-file stack.yaml \
--stack-name $STACK_NAME \
--capabilities CAPABILITY_IAM \
--parameter-overrides \
"StackPackageS3Key"="${APP_FOLDER}/${base}${TIME}.zip" \
"AppFolder"=$APP_FOLDER \
"S3LambdaBucket"=$LAMBDA_BUCKET \
"Environment"="staging" \
"Testing"="false"

它将使用 stack.yaml 创建一个名为 BatchETLpipeline 的 CloudFormation 堆栈。它将打包我们的 Lambda 函数,创建一个包并将其上传到 S3 桶中。如果该桶不存在,它将创建它。然后将部署管道。

AWSTemplateFormatVersion: '2010-09-09'
Description: An example template for a Step Functions state machine.
Parameters:

  DataLocation:
    Description: Data lake bucket with source data files.
    Type: String
    Default: s3://your.datalake.aws/data/
  AthenaResultsLocation:
    Description: S3 location for Athena query results.
    Type: String
    Default: s3://your.datalake.aws/athena/
  AthenaDatabaseName:
    Description: Athena schema names for ETL pipeline.
    Type: String
    Default: mydatabase
  S3LambdaBucket:
    Description: Use this bucket to keep your Lambda package.
    Type: String
    Default: your.datalake.aws
  StackPackageS3Key:
    Type: String
    Default: mysql_connector/stack.zip
  ServiceName:
    Type: String
    Default: mysql-connector
  Testing:
    Type: String
    Default: 'false'
    AllowedValues: ['true','false']
  Environment:
    Type: String
    Default: 'staging'
    AllowedValues: ['staging','live','test']
  AppFolder:
    Description: app.py file location inside the package, i.e. mysql_connector when ./stack/mysql_connector/app.py.
    Type: String
    Default: mysql_connector

Resources:
  LambdaExecutionRole:
    Type: "AWS::IAM::Role"
    Properties:
      AssumeRolePolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: Allow
            Principal:
              Service: lambda.amazonaws.com
            Action: "sts:AssumeRole"

  MyLambdaFunction:
    Type: "AWS::Lambda::Function"
    Properties:
      Handler: "index.handler"
      Role: !GetAtt [ LambdaExecutionRole, Arn ]
      Code:
        ZipFile: |
          exports.handler = (event, context, callback) => {
              callback(null, "Hello World!");
          };
      Runtime: "nodejs18.x"
      Timeout: "25"

### MySQL Connector Lmabda ###
  MySqlConnectorLambda:
    Type: AWS::Lambda::Function
    DeletionPolicy: Delete
    DependsOn: LambdaPolicy
    Properties:
      FunctionName: !Join ['-', [!Ref ServiceName, !Ref Environment] ]
      Handler: !Sub '${AppFolder}/app.lambda_handler'
      Description: Microservice that extracts data from RDS.
      Environment:
        Variables:
          DEBUG: true
          LAMBDA_PATH: !Sub '${AppFolder}/'
          TESTING: !Ref Testing
          ENV: !Ref Environment
      Role: !GetAtt LambdaRole.Arn
      Code:
        S3Bucket: !Sub '${S3LambdaBucket}'
        S3Key:
          Ref: StackPackageS3Key
      Runtime: python3.8
      Timeout: 360
      MemorySize: 128
      Tags:
        -
          Key: Service
          Value: Datalake

  StatesExecutionRole:
    Type: "AWS::IAM::Role"
    Properties:
      AssumeRolePolicyDocument:
        Version: "2012-10-17"
        Statement:
          - Effect: "Allow"
            Principal:
              Service:
                - !Sub states.${AWS::Region}.amazonaws.com
            Action: "sts:AssumeRole"
      Path: "/"
      Policies:
        - PolicyName: StatesExecutionPolicy
          PolicyDocument:
            Version: "2012-10-17"
            Statement:
              - Effect: Allow
                Action:
                  - "lambda:InvokeFunction"
                Resource: "*"
              - Effect: Allow
                Action:
                  - "athena:*"

                Resource: "*"
              - Effect: Allow
                Action:
                  - "s3:*"
                Resource: "*"
              - Effect: Allow
                Action:
                  - "glue:*"
                Resource: "*"

  MyStateMachine:
    Type: AWS::StepFunctions::StateMachine
    Properties:
      # StateMachineName: ETL-StateMachine
      StateMachineName: !Join ['-', ['ETL-StateMachine', !Ref ServiceName, !Ref Environment] ]
      DefinitionString:
        !Sub
          - |-
            {
              "Comment": "A Hello World example using an AWS Lambda function",
              "StartAt": "HelloWorld",
              "States": {
                "HelloWorld": {
                  "Type": "Task",
                  "Resource": "${lambdaArn}",
                  "Next": "Extract from MySQL"
                },
                "Extract from MySQL": {
                  "Resource": "${MySQLLambdaArn}",
                  "Type": "Task",
                  "Next": "Create Athena DB"
                },
                "Create Athena DB": {
                  "Resource": "arn:aws:states:::athena:startQueryExecution.sync",
                  "Parameters": {
                    "QueryString": "CREATE DATABASE if not exists ${AthenaDatabaseName}",
                    "WorkGroup": "primary",
                    "ResultConfiguration": {
                      "OutputLocation": "${AthenaResultsLocation}"
                    }
                  },
                  "Type": "Task",
                  "Next": "Show tables"
                },
                "Show tables": {
                  "Resource": "arn:aws:states:::athena:startQueryExecution.sync",
                  "Parameters": {
                    "QueryString": "show tables in ${AthenaDatabaseName}",
                    "WorkGroup": "primary",
                    "ResultConfiguration": {
                      "OutputLocation": "${AthenaResultsLocation}"
                    }
                  },
                  "Type": "Task",
                  "Next": "Get Show tables query results"
                },
                "Get Show tables query results": {
                  "Resource": "arn:aws:states:::athena:getQueryResults",
                  "Parameters": {
                    "QueryExecutionId.$": "$.QueryExecution.QueryExecutionId"
                  },
                  "Type": "Task",
                  "Next": "Decide what next"
                },
                "Decide what next": {
                  "Comment": "Based on the input table name, a choice is made for moving to the next step.",
                  "Type": "Choice",
                  "Choices": [
                    {
                      "Not": {
                        "Variable": "$.ResultSet.Rows[0].Data[0].VarCharValue",
                        "IsPresent": true
                      },
                      "Next": "Create users table (external)"
                    },
                    {
                      "Variable": "$.ResultSet.Rows[0].Data[0].VarCharValue",
                      "IsPresent": true,
                      "Next": "Check All Tables"
                    }
                  ],
                  "Default": "Check All Tables"
                },
                "Create users table (external)": {
                  "Resource": "arn:aws:states:::athena:startQueryExecution.sync",
                  "Parameters": {
                    "QueryString": "CREATE EXTERNAL TABLE ${AthenaDatabaseName}.users ( id                bigint , registration_date string ) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS INPUTFORMAT   'org.apache.hadoop.mapred.TextInputFormat' OUTPUTFORMAT   'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' LOCATION  's3://datalake.staging.liveproject/data/myschema/users/' TBLPROPERTIES (  'skip.header.line.count'='0') ;",
                    "WorkGroup": "primary",
                    "ResultConfiguration": {
                      "OutputLocation": "${AthenaResultsLocation}"
                    }
                  },
                  "Type": "Task",
                  "Next": "Create transactions table (external)"
                },
                "Create transactions table (external)": {
                  "Resource": "arn:aws:states:::athena:startQueryExecution.sync",
                  "Parameters": {
                    "QueryString": "CREATE EXTERNAL TABLE ${AthenaDatabaseName}.transactions ( transaction_id    bigint , user_id           bigint , total_cost        double , dt                string ) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS INPUTFORMAT   'org.apache.hadoop.mapred.TextInputFormat' OUTPUTFORMAT   'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' LOCATION  's3://datalake.staging.liveproject/data/myschema/transactions/' TBLPROPERTIES (  'skip.header.line.count'='0') ;",
                    "WorkGroup": "primary",
                    "ResultConfiguration": {
                      "OutputLocation": "${AthenaResultsLocation}"
                    }
                  },
                  "Type": "Task",
                  "Next": "Create report table (parquet)"
                },
                "Create report table (parquet)": {
                  "Resource": "arn:aws:states:::athena:startQueryExecution.sync",
                  "Parameters": {
                    "QueryString": "CREATE TABLE IF NOT EXISTS ${AthenaDatabaseName}.user_transactions ( dt date, user_id int, total_cost_usd float, registration_date string ) PARTITIONED BY (dt) LOCATION 's3://datalake.staging.liveproject/data/myschema/optimized-data-iceberg-parquet/' TBLPROPERTIES ( 'table_type'='ICEBERG', 'format'='parquet', 'write_target_data_file_size_bytes'='536870912', 'optimize_rewrite_delete_file_threshold'='10' ) ;",
                    "WorkGroup": "primary",
                    "ResultConfiguration": {
                      "OutputLocation": "${AthenaResultsLocation}"
                    }
                  },
                  "Type": "Task",
                  "End": true
                },
                "Check All Tables": {
                  "Type": "Map",
                  "InputPath": "$.ResultSet",
                  "ItemsPath": "$.Rows",
                  "MaxConcurrency": 0,
                  "Iterator": {
                    "StartAt": "CheckTable",
                    "States": {
                      "CheckTable": {
                        "Type": "Choice",
                        "Choices": [
                          {
                            "Variable": "$.Data[0].VarCharValue",
                            "StringMatches": "*users",
                            "Next": "passstep"
                          },
                          {
                            "Variable": "$.Data[0].VarCharValue",
                            "StringMatches": "*user_transactions",
                            "Next": "Insert New parquet Data"
                          }
                        ],
                        "Default": "passstep"
                      },
                      "Insert New parquet Data": {
                        "Resource": "arn:aws:states:::athena:startQueryExecution.sync",
                        "Parameters": {
                          "QueryString": "MERGE INTO ${AthenaDatabaseName}.user_transactions  as ut USING ( SELECT date(dt) dt , user_id , sum(total_cost) total_cost_usd , registration_date FROM ${AthenaDatabaseName}.transactions LEFT JOIN ${AthenaDatabaseName}.users ON users.id = transactions.user_id GROUP BY dt , user_id , registration_date ) as ut2 ON (ut.dt = ut2.dt and ut.user_id = ut2.user_id) WHEN MATCHED THEN UPDATE SET total_cost_usd = ut2.total_cost_usd, registration_date = ut2.registration_date WHEN NOT MATCHED THEN INSERT ( dt ,user_id ,total_cost_usd ,registration_date ) VALUES ( ut2.dt ,ut2.user_id ,ut2.total_cost_usd ,ut2.registration_date ) ;",
                          "WorkGroup": "primary",
                          "ResultConfiguration": {
                            "OutputLocation": "${AthenaResultsLocation}"
                          }
                        },
                        "Type": "Task",
                        "End": true
                      },
                      "passstep": {
                        "Type": "Pass",
                        "Result": "NA",
                        "End": true
                      }
                    }
                  },
                  "End": true
                }
              }
            }
          - {
            lambdaArn: !GetAtt [ MyLambdaFunction, Arn ],
            MySQLLambdaArn: !GetAtt [ MySqlConnectorLambda, Arn ],
            AthenaResultsLocation: !Ref AthenaResultsLocation,
            AthenaDatabaseName: !Ref AthenaDatabaseName
          }
      RoleArn: !GetAtt [ StatesExecutionRole, Arn ]
      Tags:
        -
          Key: "keyname1"
          Value: "value1"
        -
          Key: "keyname2"
          Value: "value2"

# IAM role for mysql-data-connector Lambda:
  LambdaRole:
    Type: AWS::IAM::Role
    Properties:
      AssumeRolePolicyDocument:
        Version: "2012-10-17"
        Statement:
          -
            Effect: Allow
            Principal:
              Service:
                - "lambda.amazonaws.com"
            Action:
              - "sts:AssumeRole"

  LambdaPolicy:
    Type: AWS::IAM::Policy
    DependsOn: LambdaRole
    Properties:
      Roles:
        - !Ref LambdaRole
      PolicyName: !Join ['-', [!Ref ServiceName, !Ref Environment, 'lambda-policy']] 
      PolicyDocument:
        {
            "Version": "2012-10-17",
            "Statement": [
                {
                    "Effect": "Allow",
                    "Action": [
                        "logs:CreateLogGroup",
                        "logs:CreateLogStream",
                        "logs:PutLogEvents"
                    ],
                    "Resource": "*"
                }
            ]
        }

如果一切顺利,我们的新数据管道的堆栈将被部署:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

BatchETLpipeline 堆栈和资源。图像由作者提供。

如果我们点击状态机资源,然后点击‘编辑’,我们将看到我们的 ETL 管道作为图形展示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

批量数据管道的工作流工作室。图像由作者提供。

现在我们可以执行管道以运行所有必要的数据转换步骤。点击‘开始执行’。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

成功执行。图像由作者提供。

现在我们可以将我们的 Athena 表连接到我们的BI 解决方案。连接我们最终的 Athena 数据集mydataset.user_transactions以创建仪表盘。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

连接 Quicksight 中的数据集。图像由作者提供。

我们只需调整几个设置,使我们的仪表盘看起来像这样:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Quicksight 仪表盘。图像由作者提供。

我们希望使用dt作为维度,total_cost_usd作为指标。我们还可以为每个user_id设置一个拆分维度。

结论

批处理数据管道很受欢迎,因为历史上工作负载主要是批处理型的数据环境。我们刚刚建立了一个 ETL 数据管道,从 MySQL 中提取数据并在数据湖中转换。该模式最适用于数据集不大且需要持续处理的情况,因为 Athena 根据扫描的数据量收费。这种方法在将数据转换为列式格式如 Parquet 或 ORC 时表现良好,结合几个小文件成较大的文件,或进行分桶和添加分区。我以前在我的一个故事中写过这些大数据文件格式[6]。

## 大数据文件格式解析

Parquet 与 ORC 与 AVRO 与 JSON。该选择哪一个,如何使用它们?

towardsdatascience.com

我们学习了如何使用 Step Functions 来编排数据管道,视觉化数据流从源头到最终用户,并使用基础设施即代码进行部署。这个设置使得我们可以对数据管道应用 CI/CD 技术[7]。

希望这个教程对你有帮助。如果你有任何问题,请告诉我。

推荐阅读

[1] towardsdatascience.com/data-pipeline-design-patterns-100afa4b93e3

[2] medium.com/towards-data-science/introduction-to-apache-iceberg-tables-a791f1758009

[3] medium.com/towards-data-science/advanced-sql-techniques-for-beginners-211851a28488

[4] medium.com/towards-data-science/create-mysql-and-postgres-instances-using-aws-cloudformation-d3af3c46c22a

[5] medium.com/towards-data-science/data-pipeline-orchestration-9887e1b5eb7a

[6] medium.com/towards-data-science/big-data-file-formats-explained-275876dc1fc9

[7] medium.com/towards-data-science/continuous-integration-and-deployment-for-data-platforms-817bf1b6bed1

使用 Hugging Face 的 Transformer 模型构建评论毒性排序器

原文:towardsdatascience.com/building-a-comment-toxicity-ranker-using-hugging-faces-transformer-models-aa5b4201d7c6

赶上 NLP 和 LLM(第一部分)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Jacky Kaub

·发表于Towards Data Science ·18 分钟阅读·2023 年 8 月 6 日

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

图片来源:Brett JordanUnsplash

介绍

作为一名数据科学家,我从未有机会深入探索自然语言处理的最新进展。随着夏季和今年初大语言模型的新热潮,我决定是时候深入这个领域并开始一些小项目。毕竟,没有比实践更好的学习方法了。

在我的旅程开始时,我意识到很难找到能够手把手引导读者、一步一步深入理解新 NLP 模型并通过具体项目进行的内容。因此,我决定开始这一新系列的文章。

使用 HuggingFace 的 Transformer 模型构建评论毒性排序器

在这篇文章中,我们将深入探讨构建评论毒性排序器。这个项目灵感来源于去年在 Kaggle 上举办的“Jigsaw 毒性评论严重性评估”竞赛

竞赛的目标是构建一个能够判断哪个评论(在给定的两个评论中)最具毒性的模型。

为此,模型会为每个输入的评论分配一个分数,以确定其相对毒性。

本文涵盖内容

在这篇文章中,我们将使用 Pytorch 和 Hugging Face transformers 训练我们的第一个 NLP 分类器。我不会深入讲解 transformers 的工作原理,而是更多地关注实际细节和实现,并引入一些对系列后续文章有用的概念。

具体来说,我们将看到:

  • 如何从 Hugging Face Hub 下载模型

  • 如何自定义和使用编码器

  • 从 Hugging Face 模型中构建并训练一个 Pytorch 排名器

本文直接面向希望从实际角度提升其自然语言处理技能的数据科学家。我不会详细讲解变换器的理论,即使我会详细编写代码,也希望你已经对 PyTorch 有一些了解。

探索与架构

训练数据集

我们将处理一个将评论配对并将其分类为“较少毒性”和“更多毒性”的数据集。

相对毒性的选择是由一组标注者做出的。

下图显示了来自训练集的数据样本。工人字段表示进行分类的标注者的 id。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

训练集样本,作者插图

注意:数据集在开源许可证下提供,遵循Kaggle 竞赛规则

排名系统

在任何机器学习项目中,理解任务具有至关重要的意义,因为它显著影响了合适模型和策略的选择。这种理解应从项目启动时就建立起来。

在这个具体的项目中,我们的目标是构建一个排名系统。与其预测一个具体的目标,我们的重点是确定一个任意值,以便在样本对之间进行有效比较。

让我们首先绘制一个基本的图示来表示这个概念,知道我们稍后会更深入地探讨“模型”的工作原理。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们想要实现的一个非常基本的视图

以这种方式可视化任务至关重要,因为它表明项目的目标不仅仅是基于训练数据训练一个简单的二分类器。与其仅仅预测 0 或 1 来识别最有毒的评论,排名系统旨在分配任意值,从而有效地比较评论。

模型训练与边际排名损失

考虑到“模型”仍然是一个黑箱神经网络,我们需要建立一种利用这个系统并利用我们的训练数据来更新模型权重的方法。为此,我们需要一个合适的损失函数。

鉴于我们的目标是构建一个排名系统,边际排名损失是一个相关的选择。这个损失函数受到铰链损失的启发,后者通常用于优化样本之间的最大边际。

边际排名损失对样本对进行操作。对于每一对样本,它比较“模型”对两个样本产生的分数,并强制它们之间有一个边际。这个边际表示正确排序的样本之间期望的分数差异。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

Margin Ranking Loss 函数公式,作者插图

在上述公式中,x1 和 x2 是两个样本的排名得分,y 是一个系数,如果 x1 应该排名高于 x2,则 y 等于 1,否则为 -1。 “margin” 是公式的超参数,设置了需要达到的最小间隔。

让我们看看这个损失函数是如何工作的:

假设 y=1,这意味着与 x1 相关的样本应该比与 x2 相关的样本排名更高:

  1. 如果 (x1 — x2) > margin,样本 1 的得分比样本 2 的得分高出足够的间隔,则 max() 的右侧项为负数。返回的损失将等于 0,并且这两个排名之间没有惩罚。

  2. 如果 (x1 — x2) < margin,这意味着 x1 和 x2 之间的间隔不足,或者更糟的是,x2 的得分高于 x1 的得分。在这种情况下,损失会更高,因为样本 2 的得分高于样本 1 的得分,这会惩罚模型。

鉴于此,我们现在可以按照如下修订我们的训练方法:

对于训练集中的一个样本(或一个批次):

  1. 将 more_toxic 消息传递给模型,得到 Rank_Score1 (x1)

  2. 将 less_toxic 消息传递给模型,得到 Rank_Score2 (x2)

  3. 计算 y = 1 时的 MarginRankingLoss

  4. 根据计算出的损失更新模型的权重(反向传播步骤)

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

使用 Margin Ranking Loss 的模型训练步骤,作者插图

从文本到特征表示:编码器块

我们的训练程序现在已设置完成。是时候深入了解‘模型’组件本身了。在 NLP 的世界里,你会经常遇到三种主要类型的模型:编码器、解码器和编码器-解码器组合。在这一系列文章中,我们将更详细地研究这些类型的模型。

对于本特定文章的目的,我们需要一个可以将消息转换为特征向量的模型。这个向量作为输入生成最终的排名得分。这个特征向量将直接从变换器架构的编码器中派生。

我不会在这里深入理论,因为其他人已经解释得很好(我推荐 Hugging Face 的入门课程,写得非常好)。只需记住这个过程的关键部分叫做注意力机制。它通过查看其他相关词,即使它们相隔很远,也帮助变换器理解文本。

有了这种架构,我们将能够调整权重,以生成我们文本的最佳向量表示,从而识别出对任务最重要的特征,并将最终层从变换器连接到一个最终节点(称为“头”),该节点将生成最终的排名得分。

让我们相应地更新我们的图示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

我们训练流水线的更新视图,作者插图

分词器

正如你从上述图表中看到的,模型内部出现了一个我们尚未提及的组件:预处理步骤。

这个预处理步骤旨在将原始文本转换为可以通过神经网络处理的内容(数字),而这就是分词器的作用。

分词器的主要功能有两个:分割(即将文本切割成片段,这些片段可以是单词、单词的一部分或字母)和索引(即将每个文本片段映射到一个唯一的值,该值在字典中引用,以便可以反向操作)。

需要记住的一件非常重要的事情是,文本的分词方式有多种,但如果你使用预训练模型,你需要使用相同的分词器,否则预训练权重将毫无意义(由于不同的分割和索引)。

另一个重要的事情是要记住,编码器只是一个神经网络。因此,它的输入需要是固定大小的,但你的输入文本不一定符合这一点。分词器允许你通过两个操作来控制你的词向量的大小:填充和截断。这也是一个重要的参数,因为一些预训练模型会使用更小或更大的输入空间。

在下面的图中,我们添加了分词器,并展示了消息如何从模块到模块进行转换。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

最终训练示意图,作者插图

就这样,我们已经揭示了所有需要了解的组件,以便有效地处理我们的“评论毒性排名”任务。总结上述图表:每对消息(较少毒性和较多毒性)将分别传递给模型流水线。它们将首先经过分词器、编码器和排名层,以产生一对分数。这对分数将用于计算边际排名损失,这将用于反向传播步骤中,更新编码器和最终排名层的权重,并优化它们以完成任务。

在下一部分,我们将亲自动手编写代码,使用 Hugging Face transformers 模块和 Pytorch 构建上述流水线。

构建、训练和评估模型

我们在前面的部分中涵盖了理论,现在是时候亲自动手,开始处理我们的模型了。

虽然过去构建和训练复杂的深度学习模型可能很复杂,但新的现代框架使其变得更简单。

Hugging Face 是你所需的一切

Hugging Face 是一家了不起的公司,致力于使复杂的深度学习模型民主化。

它们构建了帮助你构建、加载、微调和共享复杂变换器模型的抽象。

在接下来的部分中,我们将使用他们的transformers包,该包提供了构建预训练 NLP 模型并用于自己任务所需的所有工具。在接下来的几周内,我们将更详细地探索该包提供的不同可能性

该包与 TensorFlow 和 PyTorch 库兼容。

首先,让我们安装 transformers 包

pip install transformers

从 Hugging Face 获取的模型可以在他们的Model Hub网站上找到。你可以找到各种类型的模型以及描述,以了解模型的功能、参数数量、训练数据集等。

在本文中,我们将使用架构roberta-base,这是一个相对轻量的编码器,经过多个英文语料库的训练。

模型描述提供了大量与我们的任务相关的非常重要的信息:

  • 该模型具有 125M 个参数

  • 该模型已在多个英文语料库上进行过训练,这一点很重要,因为我们的评论数据集是英文的

  • 该模型已经在掩蔽语言模型的目标上进行过训练,这一目标是尝试预测文本中被掩蔽的单词,并使用前后的文本进行预测,这并非总是如此(例如,GPT 等模型只使用单词前的上下文来进行预测,因为它们在推断新文本时无法看到句子的未来)。

  • 该模型对大小写敏感,这意味着它会区分“WORD”和“word”。这在毒性检测器中尤为重要,因为字母的大小写是判断毒性的一个重要线索。

Hugging Face 可以为每个模型提供使用的分词器以及不同配置的基本神经网络(你可能不希望所有的权重:有时你只想限制在编码器部分,解码器部分,停留在隐藏层等)。

从 Hugging Face hub 获取的模型可以在本地克隆(这样运行会更快)或直接在代码中加载,通过使用其 repo id(例如我们案例中的 roberta-base)

加载和测试分词器

要加载分词器,我们可以简单地使用 transformers 包中的 AutoTokenizer 类,并指定我们想要使用的分词器

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('roberta-base')

为了对文本进行分词,我们可以简单地调用“encode”或“encode_plus”方法。“encode_plus”不仅会提供你文本的分词版本,还会提供一个注意力掩码,用于忽略纯填充部分的编码。

text = "hello world"

tokenizer.encode_plus(
    text,
    truncation=True,
    add_special_tokens=True,
    max_length=10,
    padding='max_length'
    )

将返回一个字典,其中“input_ids”是编码序列,“attention_mask”用于允许变换器忽略填充的标记:

{
  'input_ids': [0, 42891, 232, 2, 1, 1, 1, 1, 1, 1], 
  'attention_mask': [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
}

在我们使用的参数中,有:

  • max_length: 指定编码序列的最大长度

  • add_special_tokens: 向文本中添加和标记

  • truncation: 如果文本不适合 max_length,则会截断文本

  • padding: 添加填充标记直到 max_length

加载预训练模型

要加载预训练模型,Hugging Face 提供了多个类,具体取决于你的需求(你是在使用 TensorFlow 还是 Pytorch?你尝试实现什么类型的任务)。

在我们的案例中,我们将使用 AutoModel,它允许你直接加载模型架构及预训练权重。请注意,如果你使用 TensorFlow,你可以通过使用 TFAutoModel 类而不是 AutoModel 类来实现相同的功能。

AutoModel 类将直接从 RobertaModel 加载模型架构,并加载与 Hugging Face 中的 “roberta-base” 仓库相关联的预训练权重。

至于 Tokenizer,我们可以直接从 repo-id 或本地仓库路径加载模型,通过使用 AutoModel 的 from_pretrained 方法:

from transformers import AutoModel

robertaBase = AutoModel.from_pretrained("roberta-base")

请注意,编码器没有在特定任务上进行训练,我们不能简单地使用模型。相反,我们需要用我们的数据集进行微调。

我们可以再三检查 robertaBase 是否是 pytorch.nn.Module 的实例,并且可以集成到更复杂的 PyTorch 架构中:

import pytorch.nn as torch

isinstance(robertaBase, nn.Module)
>> True

你也可以通过简单地打印它来检查其架构,就像你对待标准 PyTorch 模块一样:

print(robertaBase)

>> RobertaModel(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(50265, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0-11): 12 x RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): RobertaIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): RobertaOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): RobertaPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)

构建自定义神经网络

这个最后的层实际上是我们在本文第一部分讨论的整个文本的向量表示,我们只需将其连接到用于排序的最终节点,以完成我们的神经网络架构。

为此,我们将通过封装 nn.Module 来简单地构建自己的自定义模块,就像我们用 PyTorch 构建经典神经网络一样。

model_name = "roberta-base"
last_hidden_layer_size = 768
final_node_size = 1

class ToxicRankModel(nn.Module):

    def __init__(self, model_name, last_hidden_layer_size):
        super(ToxicRankModel, self).__init__()
        self.robertaBase = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(p=0.1)
        self.rank_head = nn.Linear(last_hidden_layer_size, 1)

    def forward(self, ids, mask):        
        output = self.robertaBase(input_ids=ids,attention_mask=mask,
                         output_hidden_states=False)
        output = self.dropout(output[1])
        score= self.fc(output)
        return score

#This line check if the GPU is available, else it goes with the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#After initiation, we send the model to the device
toxicRankModel = ToxicRankModel(model_name, last_hidden_layer_size)
toxicRankModel = toxicRankModel.to(device)

在 forward() 方法中需要注意几点:

  1. 我们将两个主要输入传递给 robertBase 模型,input_ids 和 attention_mask。它们都是由 Tokenizer 生成的。

  2. AutoModel 具有参数(如 output_hidden_states)。根据你选择的参数,你可以让模型作为编码器或解码器运行,并将模型定制用于不同的 NLP 任务。

  3. 你是否注意到我们在 dropout 中传递了 output[1]?这是因为基本模型提供了两个输入:

  • 首先,最后的隐藏状态,它包含每个标记的上下文表示(或上下文嵌入),可以用于实体识别等任务。

  • 其次,来自 Pooler 的输出,它包含整个文本的向量表示,就是我们在这里寻找的。

构建自定义数据集

使用 Pytorch,我们还需要创建自己的 Dataset 类,用于存储原始数据,以及 DataLoader,用于在训练过程中按批次馈送神经网络。

在使用 Pytorch 构建自定义数据集时,你必须实现两个强制性方法:

  • len,它给出训练数据的大小(对数据加载器来说是重要信息)

  • getitem,它接受原始输入(来自第“i”行)并进行预处理,以便神经网络(作为张量)可以处理

如果你记得之前部分的图示,我们实际上是在计算损失之前并行传递两个输入到模型中:less_toxic 和 more_toxic。

getitem 方法将处理消息的分词,并为转换器准备输入,将分词后的输入转换为张量。

class CustomDataset(Dataset):
    def __init__(self, train_df, tokenizer, max_length):

        #token list standard size
        self.length = max_length

        #Here the tokenizer will be an instance of the tokenizer
        #shown previously
        self.tokenizer = tokenizer

        #df is the training df shown in the beginning of the article
        self.more_toxic = train_df['more_toxic'].values
        self.less_toxic = train_df['less_toxic'].values

    def __len__(self):
        return len(self.more_toxic)

    def __getitem__(self, i):
        # get both messages at index i
        message_more_toxic = self.more_toxic[i]
        message_less_toxic = self.less_toxic[i]

        #tokenize the messages
        dic_more_toxic = self.tokenizer.encode_plus(
                                message_more_toxic,
                                truncation=True,
                                add_special_tokens=True,
                                max_length=self.length,
                                padding='max_length'
                            )
        dic_less_toxic = self.tokenizer.encode_plus(
                                message_less_toxic,
                                truncation=True,
                                add_special_tokens=True,
                                max_length=self.length,
                                padding='max_length'
                            )

        #extract tokens and masks
        tokens_more_toxic = dic_more_toxic['input_ids']
        mask_more_toxic = dic_more_toxic['attention_mask']

        tokens_less_toxic = dic_less_toxic['input_ids']
        mask_less_toxic = dic_less_toxic['attention_mask']

        #return a dictionnary of tensors
        return {
            'tokens_more_toxic': torch.tensor(tokens_more_toxic, dtype=torch.long),
            'mask_more_toxic': torch.tensor(mask_more_toxic, dtype=torch.long),
            'tokens_less_toxic': torch.tensor(tokens_less_toxic, dtype=torch.long),
            'mask_less_toxic': torch.tensor(mask_less_toxic, dtype=torch.long),
        }

我们现在可以生成 DataLoader,用于模型的批量训练。

def get_loader(df, tokenizer, max_length, batch_size):

    dataset = CustomDataset(
        df, 
        tokenizer=tokenizer, 
        max_length=max_length
    )

    return DataLoader(
        dataset, 
        batch_size=batch_size, 
        shuffle=True,
        drop_last=True)

max_length = 128
batch_size = 32
train_loader = get_loader(train_df, tokenizer, max_length, batch_size=batch_size)
  • batch_size 指定了用于前向传递/反向传播的样本数量

  • shuffle = True 意味着数据集在两个 epoch 之间会被打乱

  • drop_last 意味着如果最后一个 batch 没有正确数量的样本,它将被丢弃。这一点很重要,因为 batch normalization 对于不完整的 batch 处理效果不好。

训练模型

我们快完成了,现在是时候为一个 epoch 准备训练流程了。

自定义损失

首先,让我们定义一个自定义损失函数。 Pytorch 已经提供了 MarginRankingLoss,我们只是将其封装为 y = 1(因为我们将始终将 more_toxic 作为 x1,less_toxic 作为 x2)。

from torch.nn import MarginRankingLoss

#Custom implementation of the MarginRankingLoss with y = 1
class CustomMarginRankingLoss(nn.Module):
    def __init__(self, margin=0):
        super(CustomMarginRankingLoss, self).__init__()
        self.margin = margin

    def forward(self, x1, x2):
        #with y=1 this is how looks the loss
        loss = torch.relu(x2 - x1 + self.margin)
        return loss.mean()

def criterion(x1, x2):
    return CustomMarginRankingLoss()(x1, x2)

优化器

对于这个实验,我们将使用经典的 AdamW,它目前是最先进的,并解决了原始 Adam 实现的一些问题。

optimizer_lr = 1e-4
optimizer_weight_decay = 1e-6
optimizer = AdamW(toxicRankModel.parameters(), 
                  lr=optimizer_lr, 
                  weight_decay=optimizer_weight_decay)

调度器

调度器有助于调整学习率。在开始时,我们希望较高的学习率以更快地收敛到最佳解,而在训练结束时,我们希望较小的学习率以真正微调权重。

scheduler_T_max = 500
scheduler_eta_min = 1e-6
scheduler = lr_scheduler.CosineAnnealingLR(optimizer,T_max=scheduler_T_max, eta_min=scheduler_eta_min)

训练例程

我们现在准备好训练我们的 NLP 模型以进行毒性评论排序。

使用 Pytorch 训练一个 epoch 非常简单:

  1. 我们迭代通过我们的数据加载器,它会从数据集中打乱并选择预处理的数据

  2. 我们从数据加载器中提取 tokens 和 masks

  3. 我们通过对模型进行前向传递来计算每条消息的排名

  4. 当两个排名都计算完毕后,我们可以计算 MarginRankingLoss(用于反向传播),以及一个准确率分数,表示正确分类的对数百分比(仅供参考)

  5. 我们更新我们的系统(反向传播、优化器和调度器)

  6. 我们迭代直到数据加载器中的所有数据都被使用完。

def train_one_epoch(model, optimizer, scheduler, dataloader, device):

    #Setup train mode, this is important as some layers behave differently
    # during train and inference (like batch norm)
    model.train()

    #Initialisation of some loss
    dataset_size = 0
    running_loss = 0.0
    running_accuracy = 0.0

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader), desc="Training")

    for i, data in progress_bar:
        more_toxic_ids = data['tokens_more_toxic'].to(device, dtype = torch.long)
        more_toxic_mask = data['mask_more_toxic'].to(device, dtype = torch.long)
        less_toxic_ids = data['tokens_less_toxic'].to(device, dtype = torch.long)
        less_toxic_mask = data['mask_less_toxic'].to(device, dtype = torch.long)

        batch_size = more_toxic_ids.size(0)

        #Forward pass both inputs in the model
        x1 = model(more_toxic_ids, more_toxic_mask)
        x2 = model(less_toxic_ids, less_toxic_mask)

        #Compute margin ranking loss
        loss = criterion(x1, x2)
        accuracy_measure = (x1 > x2).float().mean().item()

        #apply backpropagation, increment optimizer
        loss.backward()
        optimizer.step()
        scheduler.step()

        optimizer.zero_grad()

        #Update cumulative loss for monitoring
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        running_accuracy += (accuracy_measure * batch_size)
        epoch_accuracy = running_accuracy / dataset_size

        progress_bar.set_postfix({'loss': epoch_loss, 'accuracy': epoch_accuracy}, refresh=True)        

    #Garbage collector
    gc.collect()

    return epoch_loss

我在 Kaggle 的 GPU T4 上训练了模型,使我获得了 70% 的评论正确分类的可观成绩。我可能通过调整不同的参数和使用更多的 epochs 提高准确性,但这对于本文的目的来说已经足够了。

关于推断的最后一点

我们建立的框架在从预格式化的评论集合中训练时效果很好。

但在“生产”场景下,这种方法就不起作用了,因为你会接收到一堆需要评估毒性评分的消息。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

这是一个生产模式下的数据集示例,在这种模式下,我们只接收单条消息,而不是消息对。

对于推断,你将设计另一个 Dataset 类和另一个 DataLoader,这些将与我们之前做的有所不同:

class CustomInferenceDataset(Dataset):
    def __init__(self, messages, tokenizer, max_length):

        #token list standard size
        self.length = max_length

        #Here the tokenizer will be an instance of the tokenizer
        #shown previously
        self.tokenizer = tokenizer

        #df is the training df shown in the beginning of the article
        self.messages = messages

    def __len__(self):
        return len(self.messages)

    def __getitem__(self, i):
        # get a message at index i
        message = self.messages[i]

        #tokenize the message
        dic_messages = self.tokenizer.encode_plus(
                                message,
                                truncation=True,
                                add_special_tokens=True,
                                max_length=self.length,
                                padding='max_length'
                            )

        #extract tokens and masks
        tokens_message = dic_messages['input_ids']
        mask_message = dic_messages['attention_mask']

        #return a dictionnary of tensors
        return {
            'tokens_message': torch.tensor(tokens_message, dtype=torch.long),
            'mask_message': torch.tensor(mask_message, dtype=torch.long),
        }

def get_loader_inference(messages, tokenizer, max_length, batch_size):

    dataset = CustomInferenceDataset(
        messages, 
        tokenizer=tokenizer, 
        max_length=max_length
    )

    return DataLoader(
        dataset, 
        batch_size=batch_size, 
        shuffle=False,
        drop_last=False)

变化了什么:

  • 我们不再加载消息对,而是单条消息。

  • Loader 没有对数据进行打乱(如果你不想要与原始向量关联的随机分数带来不好的惊喜,这一点非常重要)。

  • 由于没有批量归一化计算,并且我们希望对所有数据进行推断,我们将 drop_last 设置为 False,以获取所有批次,即使是未完成的批次。

最后,为了生成排序分数:

@torch.no_grad()
def get_scores(model, test_loader, device):
    model.eval()  # Set the model to evaluation mode
    ranks = []  # List to store the rank scores

    progress_bar = tqdm(enumerate(test_loader), total=len(test_loader), desc="Scoring")

    for i, data in progress_bar:
        tokens_message = data['tokens_message'].to(device, dtype=torch.long)
        mask_message = data['mask_message'].to(device, dtype=torch.long)

        # Forward pass to get the rank scores
        rank = model(tokens_message, mask_message)
        # Convert tensor to NumPy and add to the list
        ranks+=list(rank.cpu().numpy().flatten())

    return ranks

这是推断后的前 5 条分类消息。为了保持政治正确,我在这里进行了些许审查…

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

已识别的最具毒性的消息

不太具有建设性… 😃

结论

在这篇文章中,我们利用了 Hugging Face 预训练模型和 Pytorch 生产了一个能够对消息的毒性等级进行排序的模型。

为此,我们采用了一个“小型”的“Roberta”变换器,并使用 PyTorch 在其编码器末尾连接了一个最终简单的节点。其余部分则更为经典,可能与你之前用 PyTorch 做的其他项目类似。

这个项目是对 NLP 提供的可能性的初步探索,我想简单地介绍一些基础概念,以便进一步研究更具挑战性的任务或更大的模型。

希望你喜欢阅读,如果你想玩玩这个模型,你可以从 我的 GitHub 下载一个 Notebook。

在 Julia 中构建一个符合预测的聊天机器人

原文:towardsdatascience.com/building-a-conformal-chatbot-in-julia-1ed23363a280

符合预测、LLMs 和 HuggingFace — 第一部分

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Patrick Altmeyer

·发表于Towards Data Science ·阅读时间 7 分钟·2023 年 7 月 5 日

大型语言模型(LLM)目前非常受关注。它们被用于各种任务,包括文本分类、问答和文本生成。在本教程中,我们将展示如何使用[ConformalPrediction.jl](https://juliatrustworthyai.github.io/ConformalPrediction.jl/dev/)将变换器语言模型符合化,以进行文本分类。

👀 一览

我们特别关注意图分类任务,如下图所示。首先,我们将客户查询输入到 LLM 中以生成嵌入。接着,我们训练一个分类器,将这些嵌入与可能的意图匹配。当然,对于这个监督学习问题,我们需要由输入——查询——和输出——指示真实意图的标签——组成的训练数据。最后,我们应用符合预测来量化分类器的预测不确定性。

符合预测(CP)是一种快速发展的预测不确定性量化方法。如果你不熟悉 CP,建议你首先查看我关于这一主题的三部分介绍系列,从这篇文章开始。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

符合化意图分类器的高级概述。图片由作者提供。

🤗 HuggingFace

我们将使用Banking77数据集(Casanueva 等,2020),该数据集包含 77 个与银行相关的意图中的 13,083 个查询。在模型方面,我们将使用DistilRoBERTa模型,它是RoBERTa(Liu 等,2019)的蒸馏版,并在 Banking77 数据集上进行了微调。

可以使用 [Transformers.jl](https://github.com/chengchingwen/Transformers.jl/tree/master) 包将模型从 HF 直接加载到我们正在运行的 Julia 会话中。

这个包使得在 Julia 中使用 HF 模型变得非常简单。向开发者们致敬!🙏

下面我们加载分词器tkr和模型mod。分词器用于将文本转换为整数序列,然后将其输入模型。模型输出一个隐藏状态,然后将其输入分类器,以获得每个类别的 logits。最后,这些 logits 通过 softmax 函数以获得相应的预测概率。下面我们运行几个查询来查看模型的表现。

# Load model from HF 🤗:
tkr = hgf"mrm8488/distilroberta-finetuned-banking77:tokenizer"
mod = hgf"mrm8488/distilroberta-finetuned-banking77:ForSequenceClassification"

# Test model:
query = [
    "What is the base of the exchange rates?",
    "Why is my card not working?",
    "My Apple Pay is not working, what should I do?",
]
a = encode(tkr, query)
b = mod.model(a)
c = mod.cls(b.hidden_state)
d = softmax(c.logit)
[labels[i] for i in Flux.onecold(d)]
3-element Vector{String}:
 "exchange_rate"
 "card_not_working"
 "apple_pay_or_google_pay"

🔁 MLJ接口

由于我们的包与 [MLJ.jl](https://alan-turing-institute.github.io/MLJ.jl/dev/) 接口对接,我们需要定义一个符合MLJ接口的包装模型。为了将模型添加到通用使用中,我们可能会通过 [MLJFlux.jl](https://github.com/FluxML/MLJFlux.jl) 来实现,但在本教程中,我们将简化操作,直接重载MLJBase.fitMLJBase.predict方法。

由于 HF 的模型已经是预训练的,我们不打算进一步微调,因此我们将在MLJBase.fit方法中简单地返回模型对象。MLJBase.predict方法将接收模型对象和查询,并返回预测概率。我们还需要定义MLJBase.target_scitypeMLJBase.predict_mode方法。前者告诉MLJ模型的输出类型是什么,后者可以用来检索具有最高预测概率的标签。

struct IntentClassifier <: MLJBase.Probabilistic
    tkr::TextEncoders.AbstractTransformerTextEncoder
    mod::HuggingFace.HGFRobertaForSequenceClassification
end

function IntentClassifier(;
    tkr::TextEncoders.AbstractTransformerTextEncoder, 
    mod::HuggingFace.HGFRobertaForSequenceClassification,
)
    IntentClassifier(tkr, mod)
end

function get_hidden_state(clf::IntentClassifier, query::Union{AbstractString, Vector{<:AbstractString}})
    token = encode(clf.tkr, query)
    hidden_state = clf.mod.model(token).hidden_state
    return hidden_state
end

# This doesn't actually retrain the model, but it retrieves the classifier object
function MLJBase.fit(clf::IntentClassifier, verbosity, X, y)
    cache=nothing
    report=nothing
    fitresult = (clf = clf.mod.cls, labels = levels(y))
    return fitresult, cache, report
end

function MLJBase.predict(clf::IntentClassifier, fitresult, Xnew)
    output = fitresult.clf(get_hidden_state(clf, Xnew))= UnivariateFinite(fitresult.labels,softmax(output.logit)',pool=missing)
    return p̂
end

MLJBase.target_scitype(clf::IntentClassifier) = AbstractVector{<:Finite}

MLJBase.predict_mode(clf::IntentClassifier, fitresult, Xnew) = mode.(MLJBase.predict(clf, fitresult, Xnew))

为了测试一切是否按预期工作,我们拟合了模型并为测试数据的子集生成了预测:

clf = IntentClassifier(tkr, mod)
top_n = 10
fitresult, _, _ = MLJBase.fit(clf, 1, nothing, y_test[1:top_n])
@time= MLJBase.predict(clf, fitresult, queries_test[1:top_n]);
6.818024 seconds (11.29 M allocations: 799.165 MiB, 2.47% gc time, 91.04% compilation time)

注意,即使我们使用的 LLM 并不大,但即使是简单的前向传递也需要相当的时间。

🤖 合成聊天机器人

为了将包装好的预训练模型转变为合成意图分类器,我们现在可以依靠标准 API 调用。我们首先包装我们的原子模型,并指定所需的覆盖率和方法。由于即使是简单的前向传递对我们(小)LLM 来说也非常计算密集,我们依赖于简单归纳合成分类。

conf_model = conformal_model(clf; coverage=0.95, method=:simple_inductive, train_ratio=train_ratio)
mach = machine(conf_model, queries, y)

最后,我们使用合成 LLM 构建一个简单而强大的聊天机器人,直接在 Julia REPL 中运行。在不详细探讨细节的情况下,conformal_chatbot 的工作原理如下:

  1. 提示用户解释他们的意图。

  2. 通过合成 LLM 处理用户输入,并将输出呈现给用户。

  3. 如果合成预测集包含多个标签,请提示用户要么细化输入,要么选择预测集中的选项之一。

以下代码实现了这些想法:

function prediction_set(mach, query::String)= MLJBase.predict(mach, query)[1]
    probs = pdf.(, collect(1:77))
    in_set = findall(probs .!= 0)
    labels_in_set = labels[in_set]
    probs_in_set = probs[in_set]
    _order = sortperm(-probs_in_set)
    plt = UnicodePlots.barplot(labels_in_set[_order], probs_in_set[_order], title="Possible Intents")
    return labels_in_set, plt
end

function conformal_chatbot()
    println("👋 Hi, I'm a Julia, your conformal chatbot. I'm here to help you with your banking query. Ask me anything or type 'exit' to exit ...\n")
    completed = false
    queries = ""
    while !completed
        query = readline()
        queries = queries * "," * query
        labels, plt = prediction_set(mach, queries)
        if length(labels) > 1
            println("🤔 Hmmm ... I can think of several options here. If any of these applies, simply type the corresponding number (e.g. '1' for the first option). Otherwise, can you refine your question, please?\n")
            println(plt)
        else
            println("🥳 I think you mean $(labels[1]). Correct?")
        end

        # Exit:
        if query == "exit"
            println("👋 Bye!")
            break
        end
        if query ∈ string.(collect(1:77))
            println("👍 Great! You've chosen '$(labels[parse(Int64, query)])'. I'm glad I could help you. Have a nice day!")
            completed = true
        end
    end
end

下面我们展示了两个示例查询的输出。第一个查询非常模糊(而且刚刚发现拼写错误):“transfer mondey?”。因此,预测集的大小很大。

ambiguous_query = "transfer mondey?"
prediction_set(mach, ambiguous_query)[2]
 Possible Intents              
                                           ┌                                        ┐ 
                   beneficiary_not_allowed ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.150517   
   balance_not_updated_after_bank_transfer ┤■■■■■■■■■■■■■■■■■■■■■■ 0.111409           
                     transfer_into_account ┤■■■■■■■■■■■■■■■■■■■ 0.0939535             
        transfer_not_received_by_recipient ┤■■■■■■■■■■■■■■■■■■ 0.091163               
            top_up_by_bank_transfer_charge ┤■■■■■■■■■■■■■■■■■■ 0.089306               
                           failed_transfer ┤■■■■■■■■■■■■■■■■■■ 0.0888322              
                           transfer_timing ┤■■■■■■■■■■■■■ 0.0641952                   
                      transfer_fee_charged ┤■■■■■■■ 0.0361131                         
                          pending_transfer ┤■■■■■ 0.0270795                           
                           receiving_money ┤■■■■■ 0.0252126                           
                         declined_transfer ┤■■■ 0.0164443                             
                           cancel_transfer ┤■■■ 0.0150444                             
                                           └                                        ┘ 

以下是更精炼的提示版本:“我试图给朋友转账,但失败了。” 由于不那么模糊的提示会导致较低的预测不确定性,因此它产生了较小的预测集。

refined_query = "I tried to transfer money to my friend, but it failed."
prediction_set(mach, refined_query)[2]
 Possible Intents              
                                           ┌                                        ┐ 
                           failed_transfer ┤■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■ 0.59042   
                   beneficiary_not_allowed ┤■■■■■■■ 0.139806                          
        transfer_not_received_by_recipient ┤■■ 0.0449783                              
   balance_not_updated_after_bank_transfer ┤■■ 0.037894                               
                         declined_transfer ┤■ 0.0232856                               
                     transfer_into_account ┤■ 0.0108771                               
                           cancel_transfer ┤ 0.00876369                               
                                           └                                        ┘ 

下面的视频展示了 REPL 基础聊天机器人在实际应用中的表现。你可以自己重现这个过程,并直接从你的终端运行机器人。为此,请查看我博客上的原始帖子以获取完整的源代码。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

REPL 基础的符合性聊天机器人的演示。由作者创建。

🌯 总结

这项工作是与 ING 的同事合作完成的,作为 ING Analytics 2023 实验周的一部分。我们的团队展示了符合性预测提供了对顶级-K意图分类的强大而有原则的替代方案。我们通过大众投票赢得了第一名。

当然,这里还有很多可以改进的地方。就大型语言模型而言,我们使用了一个较小的模型。在符合性预测方面,我们只关注了简单的归纳符合性分类。这是一个好的起点,但还有更高级的方法可用,这些方法已经在软件包中实现,并在竞赛中进行了研究。另一个我们没有考虑的方面是我们有许多结果类别,实际上可能希望实现类别条件覆盖。请关注未来的帖子了解更多内容。

如果你对在 Julia 中了解更多关于符合性预测的内容感兴趣,请查看代码库文档

🎉 JuliaCon 2023 即将到来,今年我将进行一场关于ConformalPrediction.jl讲座。请查看我的讲座详细信息,并浏览内容丰富的会议日程

🎓 参考文献

Casanueva, Iñigo, Tadas Temčinas, Daniela Gerz, Matthew Henderson, 和 Ivan Vulić. 2020. “使用双句子编码器的高效意图检测。” 第二届对话 AI 自然语言处理研讨会论文集 , 38–45. 在线:计算语言学协会. doi.org/10.18653/v1/2020.nlp4convai-1.5

Liu, Yinhan, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, 和 Veselin Stoyanov. 2019. “RoBERTa:一种稳健优化的 BERT 预训练方法。” arXiv. doi.org/10.48550/arXiv.1907.11692

💾 数据和模型

Banking77 数据集是从 HuggingFace 获取的。它在知识共享署名 4.0 国际许可协议(CC BY 4.0)下发布,由 PolyAI 策划,并由 Casanueva 等人(2020 年)最初发布。还要感谢 Manuel Romero 为 HuggingFace 贡献了经过微调的 DistilRoBERTa

最初发布于 https://www.paltmeyer.com 于 2023 年 7 月 5 日。

使用 OpenAI 和 FastAPI 构建记忆微服务的对话代理

原文:towardsdatascience.com/building-a-conversational-agent-with-memory-microservice-with-openai-and-fastapi-5d0102bc8df9?source=collection_archive---------1-----------------------#2023-08-17

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

充满记忆的对话,照片由 Juri GianfrancescoUnsplash 提供。

制作上下文感知的对话代理:深入探讨 OpenAI 和 FastAPI 的集成

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传 Cesar Flores

·

关注 发表在 Towards Data Science · 30 分钟阅读 · 2023 年 8 月 17 日

介绍

在本教程中,我们将探索使用 OpenAI 和 FastAPI 创建具有内存微服务的对话代理的过程。对话代理已成为各种应用程序中的关键组件,包括客户支持、虚拟助手和信息检索系统。然而,许多传统的聊天机器人实现缺乏在对话过程中保留上下文的能力,导致功能有限和令人沮丧的用户体验。这在遵循微服务架构构建代理服务时尤其具有挑战性。

GitHub 仓库的链接在文章底部。

动机

本教程的动机是解决传统聊天机器人实现的局限性,并创建一个具有内存微服务的对话代理,这在将代理部署到像 Kubernetes 这样的复杂环境中时尤为重要。在 Kubernetes 或类似的容器编排系统中,微服务经常经历重启、更新和扩展操作。在这些事件中,传统聊天机器人的对话状态将丢失,导致断裂的互动和糟糕的用户体验。

通过构建具有内存微服务的对话代理,我们可以确保在微服务重启或更新时,甚至在交互不连续的情况下,重要的对话上下文得以保留。这种状态的保存使代理能够无缝地继续之前的对话,保持连贯性,并提供更自然和个性化的用户体验。此外,这种方法符合现代应用开发的最佳实践,其中容器化的微服务通常与其他组件交互,使得内存微服务在这种分布式设置中成为对话代理架构中的有价值的补充。

我们将使用的技术栈

对于这个项目,我们将主要使用以下技术和工具:

  1. OpenAI GPT-3.5:我们将利用 OpenAI 的 GPT-3.5 语言模型,该模型能够执行各种自然语言处理任务,包括文本生成、对话管理和上下文保留。我们需要生成一个 OpenAI API 密钥,请确保访问此 URL 以管理您的密钥。

  2. FastAPI:FastAPI 将作为我们微服务的骨干,提供处理 HTTP 请求、管理对话状态和与 OpenAI API 集成的基础设施。FastAPI 非常适合用 Python 构建微服务。

开发周期

在本节中,我们将深入探讨构建具有内存微服务的对话代理的逐步过程。开发周期将包括:

  1. 环境设置:我们将创建一个虚拟环境并安装必要的依赖项,包括 OpenAI 的 Python 库和 FastAPI。

  2. 设计记忆微服务:我们将概述记忆微服务的架构和设计,该服务将负责存储和管理对话上下文。

  3. 集成 OpenAI:我们将把 OpenAI 的 GPT-3.5 模型集成到我们的应用中,并定义处理用户消息和生成响应的逻辑。

  4. 测试:我们将逐步测试我们的对话代理。

环境设置

对于这个设置,我们将使用以下结构来构建微服务。这对于在同一个项目下扩展其他服务非常方便,而且我个人喜欢这种结构。

├── Dockerfile <--- Container
├── requirements.txt <--- Libraries and Dependencies
├── setup.py <--- Build and distribute microservices as Python packages
└── src
    ├── agents <--- Name of your Microservice
    │   ├── __init__.py
    │   ├── api
    │   │   ├── __init__.py
    │   │   ├── routes.py
    │   │   └── schemas.py
    │   ├── crud.py
    │   ├── database.py
    │   ├── main.py
    │   ├── models.py
    │   └── processing.py
    └── agentsfwrk <--- Name of your Common Framework
        ├── __init__.py
        ├── integrations.py
        └── logger.py

我们需要在项目中创建一个名为src的文件夹,其中将包含服务的 Python 代码;在我们的例子中,agents包含与对话代理和 API 相关的所有代码,agentsfwrk是我们用于跨服务的通用框架。

Dockerfile包含构建镜像的指令,一旦代码准备好,requirements.txt包含我们项目中使用的库,setup.py包含构建和分发项目的指令。

目前,只需创建服务文件夹以及__init__.py文件,并将以下内容添加到项目根目录的requirements.txtsetup.py中,Dockerfile保持空白,我们将在部署周期部分回到它。

# Requirements.txt
fastapi==0.95.2
ipykernel==6.22.0
jupyter-bokeh==2.0.2
jupyterlab==3.6.3
openai==0.27.6
pandas==2.0.1
sqlalchemy-orm==1.2.10
sqlalchemy==2.0.15
uvicorn<0.22.0,>=0.21.1
# setup.py
from setuptools import find_packages, setup

setup(
    name = 'conversational-agents',
    version = '0.1',
    description = 'microservices for conversational agents',
    packages = find_packages('src'),
    package_dir = {'': 'src'},
    # This is optional btw
    author = 'XXX XXXX',
    author_email = 'XXXX@XXXXX.ai',
    maintainer = 'XXX XXXX',
    maintainer_email = 'XXXX@XXXXX.ai',
)

让我们激活虚拟环境,并在终端运行pip install -r requirements.txt。我们暂时不会运行 setup 文件,所以接下来进入下一部分。

设计通用框架

我们将设计我们的通用框架,以便在项目中构建的所有微服务中使用。这对小型项目来说不是严格必要的,但考虑到未来,你可以扩展它以使用多个 LLM 提供商,添加与自己数据交互的其他库(例如LangChainVoCode),以及其他通用功能,如语音和图像服务,而无需在每个微服务中实现它们。

创建文件夹和文件时请遵循agentsfwrk结构。每个文件及其描述如下:

└── agentsfwrk <--- Name of your Common Framework
    ├── __init__.py
    ├── integrations.py
    └── logger.py

日志记录器是一个非常基础的工具,用于设置通用日志模块,你可以按如下方式定义它:

import logging
import multiprocessing
import sys

APP_LOGGER_NAME = 'CaiApp'

def setup_applevel_logger(logger_name = APP_LOGGER_NAME, file_name = None):
    """
    Setup the logger for the application
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    sh = logging.StreamHandler(sys.stdout)
    sh.setFormatter(formatter)
    logger.handlers.clear()
    logger.addHandler(sh)
    if file_name:
        fh = logging.FileHandler(file_name)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger

def get_multiprocessing_logger(file_name = None):
    """
    Setup the logger for the application for multiprocessing
    """
    logger = multiprocessing.get_logger()
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    sh = logging.StreamHandler(sys.stdout)
    sh.setFormatter(formatter)

    if not len(logger.handlers):
        logger.addHandler(sh)

    if file_name:
        fh = logging.FileHandler(file_name)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger

def get_logger(module_name, logger_name = None):
    """
    Get the logger for the module
    """
    return logging.getLogger(logger_name or APP_LOGGER_NAME).getChild(module_name)

接下来,我们的集成层通过集成模块完成。此文件充当微服务逻辑与 OpenAI 之间的中介,并设计为以统一的方式向我们的应用程序公开 LLM 提供商。在这里,我们可以实现处理异常、错误、重试和请求或响应超时的通用方法。我从一位非常优秀的经理那里学到,要始终在外部服务/API 和我们应用的内部世界之间放置一个集成层。

集成代码定义如下:

# integrations.py
# LLM provider common module
import json
import os
import time
from typing import Union

import openai
from openai.error import APIConnectionError, APIError, RateLimitError

import agentsfwrk.logger as logger

log = logger.get_logger(__name__)

openai.api_key = os.getenv('OPENAI_API_KEY')

class OpenAIIntegrationService:
    def __init__(
        self,
        context: Union[str, dict],
        instruction: Union[str, dict]
    ) -> None:

        self.context = context
        self.instructions = instruction

        if isinstance(self.context, dict):
            self.messages = []
            self.messages.append(self.context)

        elif isinstance(self.context, str):
            self.messages = self.instructions + self.context

    def get_models(self):
        return openai.Model.list()

    def add_chat_history(self, messages: list):
        """
        Adds chat history to the conversation.
        """
        self.messages += messages

    def answer_to_prompt(self, model: str, prompt: str, **kwargs):
        """
        Collects prompts from user, appends to messages from the same conversation
        and return responses from the gpt models.
        """
        # Preserve the messages in the conversation
        self.messages.append(
            {
                'role': 'user',
                'content': prompt
            }
        )

        retry_exceptions = (APIError, APIConnectionError, RateLimitError)
        for _ in range(3):
            try:
                response = openai.ChatCompletion.create(
                    model       = model,
                    messages    = self.messages,
                    **kwargs
                )
                break
            except retry_exceptions as e:
                if _ == 2:
                    log.error(f"Last attempt failed, Exception occurred: {e}.")
                    return {
                        "answer": "Sorry, I'm having technical issues."
                    }
                retry_time = getattr(e, 'retry_after', 3)
                log.error(f"Exception occurred: {e}. Retrying in {retry_time} seconds...")
                time.sleep(retry_time)

        response_message = response.choices[0].message["content"]
        response_data = {"answer": response_message}
        self.messages.append(
            {
                'role': 'assistant',
                'content': response_message
            }
        )

        return response_data

    def answer_to_simple_prompt(self, model: str, prompt: str, **kwargs) -> dict:
        """
        Collects context and appends a prompt from a user and return response from
        the gpt model given an instruction.
        This method only allows one message exchange.
        """

        messages = self.messages + f"\n<Client>: {prompt} \n"

        retry_exceptions = (APIError, APIConnectionError, RateLimitError)
        for _ in range(3):
            try:
                response = openai.Completion.create(
                    model = model,
                    prompt = messages,
                    **kwargs
                )
                break
            except retry_exceptions as e:
                if _ == 2:
                    log.error(f"Last attempt failed, Exception occurred: {e}.")
                    return {
                        "intent": False,
                        "answer": "Sorry, I'm having technical issues."
                    }
                retry_time = getattr(e, 'retry_after', 3)
                log.error(f"Exception occurred: {e}. Retrying in {retry_time} seconds...")
                time.sleep(retry_time)

        response_message = response.choices[0].text

        try:
            response_data = json.loads(response_message)
            answer_text = response_data.get('answer')
            if answer_text is not None:
                self.messages = self.messages + f"\n<Client>: {prompt} \n" + f"<Agent>: {answer_text} \n"
            else:
                raise ValueError("The response from the model is not valid.")
        except ValueError as e:
            log.error(f"Error occurred while parsing response: {e}")
            log.error(f"Prompt from the user: {prompt}")
            log.error(f"Response from the model: {response_message}")
            log.info("Returning a safe response to the user.")
            response_data = {
                "intent": False,
                "answer": response_message
            }

        return response_data

    def verify_end_conversation(self):
        """
        Verify if the conversation has ended by checking the last message from the user
        and the last message from the assistant.
        """
        pass

    def verify_goal_conversation(self, model: str, **kwargs):
        """
        Verify if the conversation has reached the goal by checking the conversation history.
        Format the response as specified in the instructions.
        """
        messages = self.messages.copy()
        messages.append(self.instructions)

        retry_exceptions = (APIError, APIConnectionError, RateLimitError)
        for _ in range(3):
            try:
                response = openai.ChatCompletion.create(
                    model       = model,
                    messages    = messages,
                    **kwargs
                )
                break
            except retry_exceptions as e:
                if _ == 2:
                    log.error(f"Last attempt failed, Exception occurred: {e}.")
                    raise
                retry_time = getattr(e, 'retry_after', 3)
                log.error(f"Exception occurred: {e}. Retrying in {retry_time} seconds...")
                time.sleep(retry_time)

        response_message = response.choices[0].message["content"]
        try:
            response_data = json.loads(response_message)
            if response_data.get('summary') is None:
                raise ValueError("The response from the model is not valid. Missing summary.")
        except ValueError as e:
            log.error(f"Error occurred while parsing response: {e}")
            log.error(f"Response from the model: {response_message}")
            log.info("Returning a safe response to the user.")
            raise

        return response_data

关于集成模块的一些说明:

  • OpenAI 密钥被定义为名为“OPENAI_API_KEY”的环境变量,我们应该下载这个密钥并在终端中定义它,或使用python-dotenv库。

  • 有两种方法可以与 GPT 模型集成,一种用于聊天端点(answer_to_prompt),另一种用于完成端点(answer_to_simple_prompt)。我们将专注于第一个的使用。

  • 有一种方法来检查对话的目标——verify_goal_conversation,它简单地遵循代理的指示并生成总结。

设计(内存)微服务

最佳练习是设计并绘制一个图表来可视化服务需要做的事情,包括参与者及其在与服务交互时的行动。我们从简单地描述我们的应用程序开始:

  • 我们的微服务是一个人工智能代理的提供者,这些代理在某一主题上是专家,预计会根据外部消息和后续提示进行对话。

  • 我们的代理可以进行多次对话,并且包含需要持久化的内存,这意味着它们必须能够保留对话历史记录,无论与代理交互的客户端会话如何。

  • 代理在创建时应接收清晰的指示,说明如何处理对话并在对话过程中做出相应响应。

  • 对于程序化集成,代理也应遵循预期的响应格式。

我们的设计如下图所示:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

对话代理设计——作者提供的图像

通过这个简单的图表,我们知道我们的微服务需要实现负责这些特定任务的方法:

  1. 代理的创建 & 指令的定义

  2. 对话启动器 & 对话历史记录的保存

  3. 与代理聊天

我们将按照顺序编写这些功能,在此之前我们将构建应用程序的骨架。

应用程序骨架

为了启动开发,我们首先构建 FastAPI 应用程序骨架。应用程序骨架包括基本组件,如主要应用程序脚本、数据库配置、处理脚本和路由模块。主要脚本作为应用程序的入口点,我们在此处设置 FastAPI 实例。

主要文件

在你的agents文件夹中创建/打开main.py文件并输入以下代码,该代码简单地定义了一个根端点。

from fastapi import FastAPI

from agentsfwrk.logger import setup_applevel_logger

log = setup_applevel_logger(file_name = 'agents.log')

app = FastAPI()

@app.get("/")
async def root():
    return {"message": "Hello there conversational ai user!"}

数据库配置

然后我们创建/打开名为database.py的数据库配置脚本,该脚本建立与本地数据库的连接,用于存储和检索对话上下文。我们将首先使用本地 SQLite 以简化操作,但可以根据你的环境尝试其他数据库。

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

SQLALCHEMY_DATABASE_URL = "sqlite:///agents.db"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL, connect_args = {"check_same_thread": False}
)
SessionLocal = sessionmaker(autocommit = False, autoflush = False, bind = engine)

Base = declarative_base()

API 路由

最后,我们定义处理传入 HTTP 请求的路由模块,涵盖处理用户交互的端点。让我们创建api文件夹,创建/打开routes.py文件,并粘贴以下代码。

from typing import List

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session

import agents.api.schemas
import agents.models
from agents.database import SessionLocal, engine

from agentsfwrk import integrations, logger

log = logger.get_logger(__name__)

agents.models.Base.metadata.create_all(bind = engine)

# Router basic information
router = APIRouter(
    prefix = "/agents",
    tags = ["Chat"],
    responses = {404: {"description": "Not found"}}
)

# Dependency: Used to get the database in our endpoints.
def get_db():
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()

# Root endpoint for the router.
@router.get("/")
async def agents_root():
    return {"message": "Hello there conversational ai!"}

有了这个结构化的骨架,我们已经准备好开始编写我们设计的应用程序。

创建代理并分配指令

在本节中,我们将重点实现“创建代理”端点。此端点使用户能够启动新的对话并与代理互动,提供上下文和一组指令,以便代理在整个对话过程中遵循。我们将首先介绍两个数据模型:一个用于数据库,另一个用于 API。我们将使用Pydantic来创建数据模型。创建/打开schemas.py文件,并定义 Agent base、Agent Create 和 Agent 数据模型。

from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel

class AgentBase(BaseModel): # <-- Base model
    context: str # <-- Our agents context
    first_message: str # <-- Our agents will approach the users with a first message.
    response_shape: str # <-- The expected shape (for programatic communication) of the response of each agent's interaction with the user
    instructions: str # <-- Set of instructions that our agent should follow.

class AgentCreate(AgentBase): # <-- Creation data model
    pass

class Agent(AgentBase): # <-- Agent data model
    id: str
    timestamp: datetime = datetime.utcnow()

    class Config:
        orm_mode = True

agent 数据模型中的字段如下所述:

  • 上下文:这是代理的整体背景。

  • 首条消息:我们的代理旨在与用户开始对话。这可以简单到“你好,我可以帮你做什么?”或者类似“嗨,你请求一个代理来帮助你找到有关股票的信息,对吗?”。

  • 响应格式:该字段主要用于指定代理响应的输出格式,并应用于将 LLM 的文本输出转换为所需的格式,以便进行程序化通信。例如,我们可能希望指定我们的代理应该将响应包装在一个名为response的 JSON 格式中,即{'response': "string"}

  • 指令:该字段包含每个代理在整个对话过程中应遵循的指令和指南,例如“在每次交互中收集以下实体 [e1, e2, e3, …]”或“回复用户直到他不再对对话感兴趣”或“不要偏离主题,并在必要时将对话引导回主要目标”。

我们现在继续打开models.py文件,在其中编写属于 agent 实体的数据库表。

from sqlalchemy import Column, ForeignKey, String, DateTime, JSON
from sqlalchemy.orm import relationship
from datetime import datetime

from agents.database import Base

class Agent(Base):
    __tablename__ = "agents"

    id          = Column(String, primary_key = True, index = True)
    timestamp   = Column(DateTime, default = datetime.utcnow)

    context            = Column(String, nullable = False)
    first_message      = Column(String, nullable = False)
    response_shape     = Column(JSON,   nullable = False)
    instructions       = Column(String, nullable = False)

这段代码与 Pydantic 模型非常相似,它定义了我们数据库中的代理表。

在我们有了两个数据模型后,我们准备好实现代理的创建。为此,我们将首先修改routes.py文件,添加端点:

@router.post("/create-agent", response_model = agents.api.schemas.Agent)
async def create_agent(campaign: agents.api.schemas.AgentCreate, db: Session = Depends(get_db)):
    """
    Create an agent
    """
    log.info(f"Creating agent")
    # db_agent = create_agent(db, agent)
    log.info(f"Agent created with id: {db_agent.id}")

    return db_agent

我们需要创建一个新函数,该函数接收来自请求的 Agent 对象,并将其保存到数据库中。为此,我们将创建/打开crud.py文件,该文件将包含所有与数据库的交互**(创建、读取、更新、删除)**。

# crud.py
import uuid
from sqlalchemy.orm import Session
from agents import models
from agents.api import schemas

def create_agent(db: Session, agent: schemas.AgentCreate):
    """
    Create an agent in the database
    """
    db_agent = models.Agent(
        id              = str(uuid.uuid4()),
        context         = agent.context,
        first_message   = agent.first_message,
        response_shape  = agent.response_shape,
        instructions    = agent.instructions
    )
    db.add(db_agent)
    db.commit()
    db.refresh(db_agent)

    return db_agent

创建完函数后,我们现在可以回到routes.py,导入crud模块,并在端点方法中使用它。

import agents.crud

@router.post("/create-agent", response_model = agents.api.schemas.Agent)
async def create_agent(agent: agents.api.schemas.AgentCreate, db: Session = Depends(get_db)):
    """
    Create an agent endpoint.
    """
    log.info(f"Creating agent: {agent.json()}")
    db_agent = agents.crud.create_agent(db, agent)
    log.info(f"Agent created with id: {db_agent.id}")

    return db_agent

现在让我们回到main.py文件,添加“agents”路由。修改

# main.py
from fastapi import FastAPI

from agents.api.routes import router as ai_agents # NOTE: <-- new addition
from agentsfwrk.logger import setup_applevel_logger

log = setup_applevel_logger(file_name = 'agents.log')

app = FastAPI()
app.include_router(router = ai_agents) # NOTE: <-- new addition

@app.get("/")
async def root():
    return {"message": "Hello there conversational ai user!"}

让我们测试一下这个功能。首先,我们需要将我们的服务安装为 Python 包,其次,在 8000 端口启动应用程序。

# Run from the root of the project.
$ pip install -e .
# Command to run the app.
$ uvicorn agents.main:app --host 0.0.0.0 --port 8000 --reload

访问 0.0.0.0:8000/docs,你将看到带有测试端点的 Swagger UI。提交你的负载并检查输出。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

create-agent 端点来自 Swagger UI — 图片由作者提供

我们将继续开发我们的应用程序,但测试第一个端点是进展的良好标志。

创建对话 & 保留对话历史

我们的下一步是允许用户与我们的代理进行交互。我们希望用户能够与特定的代理进行互动,因此我们需要传递代理的 ID 以及用户的第一次互动消息。让我们对 Agent 数据模型进行一些修改,通过引入 Conversation 实体,使每个代理能够进行多个对话。打开 schemas.py 文件并添加以下模型:

class ConversationBase(BaseModel): # <-- base of our conversations, they must belong to an agent
    agent_id: str

class ConversationCreate(ConversationBase): # <-- conversation creation object
    pass

class Conversation(ConversationBase): # <-- The conversation objects
    id: str
    timestamp: datetime = datetime.utcnow()

    class Config:
        orm_mode = True

class Agent(AgentBase): # <-- Agent data model
    id: str
    timestamp: datetime = datetime.utcnow()
    conversations: List[Conversation] = [] # <-- NOTE: we have added the conversation as a list of Conversations objects.

    class Config:
        orm_mode = True

请注意,我们已经修改了 Agent 数据模型,并添加了对话功能,以便每个代理可以根据我们的图表设计进行多个对话。

我们需要修改我们的数据库对象,并在数据库模型脚本中包含对话表。我们将打开 models.py 文件,并按如下方式修改代码:

# models.py

class Agent(Base):
    __tablename__ = "agents"

    id          = Column(String, primary_key = True, index = True)
    timestamp   = Column(DateTime, default = datetime.utcnow)

    context            = Column(String, nullable = False)
    first_message      = Column(String, nullable = False)
    response_shape     = Column(JSON,   nullable = False)
    instructions       = Column(String, nullable = False)

    conversations      = relationship("Conversation", back_populates = "agent") # <-- NOTE: We add the conversation relationship into the agents table

class Conversation(Base):
    __tablename__ = "conversations"

    id          = Column(String, primary_key = True, index = True)
    agent_id    = Column(String, ForeignKey("agents.id"))
    timestap    = Column(DateTime, default = datetime.utcnow)

    agent       = relationship("Agent", back_populates = "conversations") # <-- We add the relationship between the conversation and the agent

请注意我们在 agents 表中为每个代理添加了对话之间的关系,以及在 conversations 表中对话与代理之间的关系。

我们现在将创建一组 CRUD 函数,以通过它们的 ID 检索代理和对话,这将帮助我们制定创建对话和保留对话历史的过程。让我们打开 crud.py 文件并添加以下函数:

def get_agent(db: Session, agent_id: str):
    """
    Get an agent by its id
    """
    return db.query(models.Agent).filter(models.Agent.id == agent_id).first()

def get_conversation(db: Session, conversation_id: str):
    """
    Get a conversation by its id
    """
    return db.query(models.Conversation).filter(models.Conversation.id == conversation_id).first()

def create_conversation(db: Session, conversation: schemas.ConversationCreate):
    """
    Create a conversation
    """
    db_conversation = models.Conversation(
        id          = str(uuid.uuid4()),
        agent_id    = conversation.agent_id,
    )
    db.add(db_conversation)
    db.commit()
    db.refresh(db_conversation)

    return db_conversation

这些新函数将帮助我们在应用程序的正常工作流程中,现在我们可以通过 ID 获取代理,通过 ID 获取对话,并通过提供可选的 ID 和应持有对话的代理 ID 来创建对话。

我们可以继续创建一个创建对话的端点。打开 routes.py 并添加以下代码:

@router.post("/create-conversation", response_model = agents.api.schemas.Conversation)
async def create_conversation(conversation: agents.api.schemas.ConversationCreate, db: Session = Depends(get_db)):
    """
    Create a conversation linked to an agent
    """
    log.info(f"Creating conversation assigned to agent id: {conversation.agent_id}")
    db_conversation = agents.crud.create_conversation(db, conversation)
    log.info(f"Conversation created with id: {db_conversation.id}")

    return db_conversation

在这个方法准备好后,我们仍然离拥有实际的对话端点还差一步,我们将在下一节中进行回顾。

在初始化代理时,重要的是要做出区分,我们可以创建一个对话而不触发双向消息交换,另一种方式是当调用“与代理聊天”端点时触发对话的创建。这为在微服务外部组织工作流提供了一些灵活性,在某些情况下,你可能想初始化代理,提前启动与客户的对话,并随着消息的到来开始保留消息的历史记录。

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

create-conversation 端点来自 Swagger UI — 图片由作者提供

重要提示: 如果您按照本指南逐步操作,并且在此步骤中看到与数据库模式相关的错误,请注意,这是因为我们在每次修改模式时都未将迁移应用到数据库,因此请确保关闭应用程序(退出终端命令)并删除在运行时创建的agents.db文件。 您需要重新运行每个端点并记录 ID。

与代理人聊天

我们现在要介绍我们应用程序中的最后一个实体类型,即Message实体。 这个实体负责建模客户消息和代理消息之间的交互(消息的双向交换)。 我们还将添加用于定义端点响应结构的 API 数据模型。 让我们先创建数据模型和 API 响应类型; 打开schemas.py文件,并修改代码:

##########################################
# Internal schemas
##########################################
class MessageBase(BaseModel): # <-- Every message is composed by user/client message and the agent 
    user_message: str
    agent_message: str

class MessageCreate(MessageBase):
    pass

class Message(MessageBase): # <-- Data model for the Message entity
    id: str
    timestamp: datetime = datetime.utcnow()
    conversation_id: str

    class Config:
        orm_mode = True

##########################################
# API schemas
##########################################
class UserMessage(BaseModel):
    conversation_id: str
    message: str

class ChatAgentResponse(BaseModel):
    conversation_id: str
    response: str

现在我们必须在代表数据库中表的数据库模型脚本中添加数据模型。 打开models.py文件并修改如下:

# models.py

class Conversation(Base):
    __tablename__ = "conversations"

    id          = Column(String, primary_key = True, index = True)
    agent_id    = Column(String, ForeignKey("agents.id"))
    timestap    = Column(DateTime, default = datetime.utcnow)

    agent       = relationship("Agent", back_populates = "conversations")
    messages    = relationship("Message", back_populates = "conversation") # <-- We define the relationship between the conversation and the multiple messages in them.

class Message(Base):
    __tablename__ = "messages"

    id          = Column(String, primary_key = True, index = True)
    timestamp   = Column(DateTime, default = datetime.utcnow)

    user_message    = Column(String)
    agent_message   = Column(String)

    conversation_id = Column(String, ForeignKey("conversations.id")) # <-- A message belongs to a conversation
    conversation    = relationship("Conversation", back_populates = "messages") # <-- We define the relationship between the messages and the conversation.

请注意,我们已修改了Conversations表以定义消息与会话之间的关系,并创建了一个新表,表示应属于对话的交互(消息交换)。

现在我们将向数据库添加一个新的 CRUD 函数,以与数据库交互并为对话创建消息。 打开crud.py文件并添加以下函数:

def create_conversation_message(db: Session, message: schemas.MessageCreate, conversation_id: str):
    """
    Create a message for a conversation
    """
    db_message = models.Message(
        id              = str(uuid.uuid4()),
        user_message    = message.user_message,
        agent_message   = message.agent_message,
        conversation_id = conversation_id
    )
    db.add(db_message)
    db.commit()
    db.refresh(db_message)

    return db_message

现在我们准备构建最终和最有趣的端点,chat-agent端点。 打开routes.py文件,并按照代码进行操作,因为我们将在途中实施一些处理函数。

@router.post("/chat-agent", response_model = agents.api.schemas.ChatAgentResponse)
async def chat_completion(message: agents.api.schemas.UserMessage, db: Session = Depends(get_db)):
    """
    Get a response from the GPT model given a message from the client using the chat
    completion endpoint.

    The response is a json object with the following structure:
    ```

    {

        `"conversation_id": "string",

        `"response": "string"

    }

    ```py
    """
    log.info(f"User conversation id: {message.conversation_id}")
    log.info(f"User message: {message.message}")

    conversation = agents.crud.get_conversation(db, message.conversation_id)

    if not conversation:
        # If there are no conversations, we can choose to create one on the fly OR raise an exception.
        # Which ever you choose, make sure to uncomment when necessary.

        # Option 1:
        # conversation = agents.crud.create_conversation(db, message.conversation_id)

        # Option 2:
        return HTTPException(
            status_code = 404,
            detail = "Conversation not found. Please create conversation first."
        )

    log.info(f"Conversation id: {conversation.id}")

在端点的这一部分中,我们确保在对话不存在时创建或引发异常。 下一步是准备数据,将其通过我们的集成发送到 OpenAI,为此,我们将在processing.py文件中创建一组处理函数,这些函数将从 LLM 中制作上下文,第一条消息,说明和预期的响应形状。

# processing.py

import json

########################################
# Chat Properties
########################################
def craft_agent_chat_context(context: str) -> dict:
    """
    Craft the context for the agent to use for chat endpoints.
    """
    agent_chat_context = {
        "role": "system",
        "content": context
    }
    return agent_chat_context

def craft_agent_chat_first_message(content: str) -> dict:
    """
    Craft the first message for the agent to use for chat endpoints.
    """
    agent_chat_first_message = {
        "role": "assistant",
        "content": content
    }
    return agent_chat_first_message

def craft_agent_chat_instructions(instructions: str, response_shape: str) -> dict:
    """
    Craft the instructions for the agent to use for chat endpoints.
    """
    agent_instructions = {
        "role": "user",
        "content": instructions + f"\n\nFollow a RFC8259 compliant JSON with a shape of: {json.dumps(response_shape)} format without deviation."
    }
    return agent_instructions

注意最后一个函数期望在代理人创建过程中定义的response_shape,此输入将在对话过程中附加到 LLM,并指导代理人遵循指南并将响应作为 JSON 对象返回。

让我们返回routes.py文件并完成我们的端点实现:

# New imports from the processing module.
from agents.processing import (
  craft_agent_chat_context,
  craft_agent_chat_first_message,
  craft_agent_chat_instructions
)

@router.post("/chat-agent", response_model = agents.api.schemas.ChatAgentResponse)
async def chat_completion(message: agents.api.schemas.UserMessage, db: Session = Depends(get_db)):
    """
    Get a response from the GPT model given a message from the client using the chat
    completion endpoint.

    The response is a json object with the following structure:
    ```

    {

        `"conversation_id": "string",

        `"response": "string"

    }

    ```py
    """
    log.info(f"User conversation id: {message.conversation_id}")
    log.info(f"User message: {message.message}")

    conversation = agents.crud.get_conversation(db, message.conversation_id)

    if not conversation:
        # If there are no conversations, we can choose to create one on the fly OR raise an exception.
        # Which ever you choose, make sure to uncomment when necessary.

        # Option 1:
        # conversation = agents.crud.create_conversation(db, message.conversation_id)

        # Option 2:
        return HTTPException(
            status_code = 404,
            detail = "Conversation not found. Please create conversation first."
        )

    log.info(f"Conversation id: {conversation.id}")

    # NOTE: We are crafting the context first and passing the chat messages in a list
    # appending the first message (the approach from the agent) to it.
    context = craft_agent_chat_context(conversation.agent.context)
    chat_messages = [craft_agent_chat_first_message(conversation.agent.first_message)]

    # NOTE: Append to the conversation all messages until the last interaction from the agent
    # If there are no messages, then this has no effect.
    # Otherwise, we append each in order by timestamp (which makes logical sense).
    hist_messages = conversation.messages
    hist_messages.sort(key = lambda x: x.timestamp, reverse = False)
    if len(hist_messages) > 0:
        for mes in hist_messages:
            log.info(f"Conversation history message: {mes.user_message} | {mes.agent_message}")
            chat_messages.append(
                {
                    "role": "user",
                    "content": mes.user_message
                }
            )
            chat_messages.append(
                {
                    "role": "assistant",
                    "content": mes.agent_message
                }
            )
    # NOTE: We could control the conversation by simply adding
    # rules to the length of the history.
    if len(hist_messages) > 10:
        # Finish the conversation gracefully.
        log.info("Conversation history is too long, finishing conversation.")
        api_response = agents.api.schemas.ChatAgentResponse(
            conversation_id = message.conversation_id,
            response        = "This conversation is over, good bye."
        )
        return api_response

    # Send the message to the AI agent and get the response
    service = integrations.OpenAIIntegrationService(
        context = context,
        instruction = craft_agent_chat_instructions(
            conversation.agent.instructions,
            conversation.agent.response_shape
        )
    )
    service.add_chat_history(messages = chat_messages)

    response = service.answer_to_prompt(
        # We can test different OpenAI models.
        model               = "gpt-3.5-turbo",
        prompt              = message.message,
        # We can test different parameters too.
        temperature         = 0.5,
        max_tokens          = 1000,
        frequency_penalty   = 0.5,
        presence_penalty    = 0
    )

    log.info(f"Agent response: {response}")

    # Prepare response to the user
    api_response = agents.api.schemas.ChatAgentResponse(
        conversation_id = message.conversation_id,
        response        = response.get('answer')
    )

    # Save interaction to database
    db_message = agents.crud.create_conversation_message(
        db = db,
        conversation_id = conversation.id,
        message = agents.api.schemas.MessageCreate(
            user_message = message.message,
            agent_message = response.get('answer'),
        ),
    )
    log.info(f"Conversation message id {db_message.id} saved to database")

    return api_response

Voilà! 这是我们最终的端点实现,如果我们查看代码中添加的Notes,我们会发现这个过程非常简单:

  1. 我们确保在我们的数据库中存在对话(或者我们创建一个)

  2. 我们从数据库中制作上下文和指导代理人

  3. 我们通过获取代理人的对话历史来利用代理人的“记忆”

  4. 最后,我们通过 OpenAI 的 GPT-3.5 Turbo 模型请求代理的响应,并将响应返回给客户端。

本地测试我们的代理

现在我们准备测试微服务的完整工作流,我们将首先进入终端,输入 uvicorn agents.main:app — host 0.0.0.0 — port 8000 — reload 启动应用程序。接下来,我们将通过访问 0.0.0.0:8000/docs 进入 Swagger UI 并提交以下请求:

  • 创建代理:提供你想测试的有效负载。我将提交以下内容:
{
    "context": "You are a chef specializing in Mediterranean food that provides receipts with a maximum of simple 10 ingredients. The user can have many food preferences or ingredient preferences, and your job is always to analyze and guide them to use simple ingredients for the recipes you suggest and these should also be Mediterranean. The response should include detailed information on the recipe. The response should also include questions to the user when necessary. If you think your response may be inaccurate or vague, do not write it and answer with the exact text: `I don't have a response.`",
    "first_message": "Hello, I am your personal chef and cooking advisor and I am here to help you with your meal preferences and your cooking skills. What can I can do for you today?",
    "response_shape": "{'recipes': 'List of strings with the name of the recipes', 'ingredients': 'List of the ingredients used in the recipes', 'summary': 'String, summary of the conversation'}",
    "instructions": "Run through the conversation messages and discard any messages that are not relevant for cooking. Focus on extracting the recipes that were mentioned in the conversation and for each of them extract the list of ingredients. Make sure to provide a summary of the conversation when asked."
}
  • 创建对话:将对话分配给从上一个响应中获取的 agent_id
{
    "agent_id": "Replace with the UUID from the agent you just created."
} 
  • 让我们通过外发消息与我们的代理聊天:我们的代理将通过提问或发送消息来初始化对话。我们将通过回复消息来继续对话 — 我将使用以下消息:
{
  "conversation_id": "74485ccb-21e5-45bd-8008-001cbd6d3e37",
  "message": "Hey I am interested in cooking dinner for me and my friends tomorrow night, and I need your advice on what to cook for 5 people."
}

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

chat-agent 端点响应来自 Swagger UI — 图片由作者提供

我们的代理给我们回复了响应,我们可以通过自然的方式继续这个对话。

# Our question
{
  "conversation_id": "74485ccb-21e5-45bd-8008-001cbd6d3e37",
  "message": "There is someone that is allergic to peanuts and that's it. My only preference is seafood and anything that goes perfectly with white wine."
}
# Our agents response
{
  "conversation_id": "74485ccb-21e5-45bd-8008-001cbd6d3e37",
  "response": "Great! We can definitely work around the peanut allergy and incorporate seafood into the menu. How about we start with a refreshing appetizer of Grilled Shrimp Skewers with Lemon and Garlic? It's simple, delicious, and pairs well with white wine.\n\nHere's the recipe for Grilled Shrimp Skewers with Lemon and Garlic:\n\nIngredients:\n1\. 1 pound of large shrimp, peeled and deveined\n2\. 2 cloves of garlic, minced\n3\. 2 tablespoons of fresh lemon juice\n4\. 2 tablespoons of olive oil\n5\. Salt and pepper to taste\n6\. Wooden skewers, soaked in water for 30 minutes\n\nInstructions:\n1\. Preheat your grill to medium-high heat.\n2\. In a bowl, combine minced garlic, lemon juice, olive oil, salt, and pepper.\n3\. Thread the shrimp onto the soaked wooden skewers.\n4\. Brush the shrimp skewers with the garlic-lemon marinade on both sides.\n5\. Place the skewers on the preheated grill and cook for about 2-3 minutes per side until they turn pink and opaque.\n6\. Remove from the grill and serve hot.\n\nFor the main course, how about a Mediterranean-style Baked Salmon? It's flavorful, healthy, and pairs wonderfully with white wine.\n\nHere's the recipe for Mediterranean Baked Salmon:\n\nIngredients:\n1\. 1 pound of salmon fillets\n2\. 2 tablespoons of olive oil\n3\. 2 cloves of garlic, minced\n4\. 1 teaspoon of dried oregano\n5\. Juice of half a lemon\n6\. Salt and pepper to taste\n\nInstructions:\n1\. Preheat your oven to 375°F (190°C).\n2\. Place the salmon fillets on a baking sheet lined with parchment paper or aluminum foil.\n3\. In a small bowl, mix together olive oil, minced garlic, dried oregano, lemon juice, salt, and pepper.\n4\. Brush the mixture over the salmon fillets, ensuring they are evenly coated.\n5\. Bake the salmon in the preheated oven for about 12-15 minutes or until it flakes easily with a fork.\n6\. Remove from the oven and let it rest for a few minutes before serving.\n\nTo complement the meal, you can serve a side of Roasted Mediterranean Vegetables. It's a colorful and nutritious addition to your dinner.\n\nHere's the recipe for Roasted Mediterranean Vegetables:\n\nIngredients:\n1\. 1 red bell pepper, sliced\n2\. 1 yellow bell pepper, sliced\n3\. 1 zucchini, sliced\n4\. 1 eggplant, diced\n5\. 1 red onion, sliced\n6\. 2 tablespoons of olive oil\n7\. Salt and pepper to taste\n8\. Fresh herbs like basil or parsley (optional)\n\nInstructions:\n1\. Preheat your oven to 425°F (220°C).\n2\. In a large bowl, toss together all the vegetables with olive oil, salt, and pepper until well coated.\n3\. Spread the vegetables in a single layer on a baking sheet.\n4\. Roast in the preheated oven for about 20-25 minutes or until they are tender and slightly caramelized.\n5\. Remove from the oven and sprinkle with fresh herbs if desired.\n\nFor dessert, you can keep it simple yet satisfying with a classic Mediterranean dessert like Fresh Fruit Salad with Honey and Mint.\n\nHere's how you can prepare Fresh Fruit Salad with Honey and Mint:\n\nIngredients:\n1\. Assorted fresh fruits of your choice (such as berries, melons, grapes)\n2\. Honey for drizzling\n3\. Fresh mint leaves for garnish\n\nInstructions:\n1\. Wash and cut the fresh fruits into bite-sized pieces.\n2\. Combine all the fruits in a bowl.\n3\. Drizzle honey over the fruit salad and gently toss to coat.\n4\. Garnish with fresh mint leaves.\n5\. Serve chilled.\n\nI hope you and your friends enjoy this Mediterranean dinner menu! If you have any other questions or need further assistance, feel free to ask."
}
# Our question - Note that we are not mentioning the allergy, as we expect the agent to remember we are talking about an allergy to peanuts.
{
  "conversation_id": "74485ccb-21e5-45bd-8008-001cbd6d3e37",
  "message": "Is the recipe ok for my friend who is allergic?"
}
# Our agents response - Note the mention of these two indices.
{
  "conversation_id": "74485ccb-21e5-45bd-8008-001cbd6d3e37",
  "response": "Yes, the Grilled Shrimp Skewers with Lemon and Garlic recipe should be safe for your friend with a peanut allergy. However, it's always important to double-check the ingredients you use to ensure they are free from any potential allergens or cross-contamination."
}

继续尝试代码和新的代理。在下一部分,我将重点介绍服务的部署。

部署周期

我们将在云的容器环境中部署应用程序,例如 Kubernetes、Azure Container Service 或 AWS Elastic Container Service。在这里,我们创建一个 docker 镜像并上传代码,以便在这些环境中的一个中运行,继续打开我们一开始创建的 Dockerfile,并粘贴以下代码:

# Dockerfile
FROM python:3.10-slim-bullseye

# Set the working directory
WORKDIR /app

# Copy the project files to the container
COPY . .

# Install the package using setup.py
RUN pip install -e .

# Install dependencies
RUN pip install pip -U && \
    pip install --no-cache-dir -r requirements.txt

# Set the environment variable
ARG OPENAI_API_KEY
ENV OPENAI_API_KEY=$OPENAI_API_KEY

# Expose the necessary ports
EXPOSE 8000

# Run the application
# CMD ["uvicorn", "agents.main:app", "--host", "0.0.0.0", "--port", "8000"]

Dockerfile 安装应用程序,然后通过 CMD 运行它,但 CMD 被注释掉了。如果你想作为独立应用本地运行,应该取消注释该命令,但对于 Kubernetes 等其他服务,这在定义部署或清单中的 pods 时已经定义。

构建镜像,等待构建完成,然后通过运行下面的运行命令进行测试:

# Build the image
$ docker build - build-arg OPENAI_API_KEY=<Replace with your OpenAI Key> -t agents-app .
# Run the container with the command from the agents app (Use -d flag for the detached run).
$ docker run -p 8000:8000 agents-app uvicorn agents.main:app --host 0.0.0.0 --port 8000
# Output
INFO:     Started server process [1]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO:     172.17.0.1:41766 - "GET / HTTP/1.1" 200 OK
INFO:     172.17.0.1:41766 - "GET /favicon.ico HTTP/1.1" 404 Not Found
INFO:     172.17.0.1:41770 - "GET /docs HTTP/1.1" 200 OK
INFO:     172.17.0.1:41770 - "GET /openapi.json HTTP/1.1" 200 OK

太好了,你准备好在你的部署环境中开始使用应用程序了。

最后,我们将尝试将这个微服务与前端应用程序集成,通过内部调用端点来服务代理和对话,这是使用这种架构构建和交互服务的常见方式。

使用周期

我们可以以多种方式使用这个新服务,我将重点关注构建一个前端应用程序,该应用程序调用我们的代理端点,使用户能够通过 UI 进行交互。我们将使用 Streamlit 来实现,因为它是使用 Python 快速搭建前端的简单方法。

重要说明: 我在我们的代理服务中添加了额外的工具,你可以直接从代码库中复制这些工具。搜索 get_agents()get_conversations()get_messages() 这几个函数,分别在 crud.py 模块和 api/routes.py 路由中查找。

  • 安装 Streamlit 并将其添加到我们的 requirements.txt 文件中。
# Pin a version if you need
$ pip install streamlit==1.25.0
# Our requirements.txt (added streamlit)
$ cat requirements.txt
fastapi==0.95.2
ipykernel==6.22.0
jupyter-bokeh==2.0.2
jupyterlab==3.6.3
openai==0.27.6
pandas==2.0.1
sqlalchemy-orm==1.2.10
sqlalchemy==2.0.15
streamlit==1.25.0
uvicorn<0.22.0,>=0.21.1
  • 创建应用程序 首先在我们的 src 文件夹中创建一个名为 frontend 的文件夹。创建一个名为 main.py 的新文件,并放入以下代码。
import streamlit as st
import requests

API_URL = "http://0.0.0.0:8000/agents"  # We will use our local URL and port defined of our microservice for this example

def get_agents():
    """
    Get the list of available agents from the API
    """
    response = requests.get(API_URL + "/get-agents")
    if response.status_code == 200:
        agents = response.json()
        return agents

    return []

def get_conversations(agent_id: str):
    """
    Get the list of conversations for the agent with the given ID
    """
    response = requests.get(API_URL + "/get-conversations", params = {"agent_id": agent_id})
    if response.status_code == 200:
        conversations = response.json()
        return conversations

    return []

def get_messages(conversation_id: str):
    """
    Get the list of messages for the conversation with the given ID
    """
    response = requests.get(API_URL + "/get-messages", params = {"conversation_id": conversation_id})
    if response.status_code == 200:
        messages = response.json()
        return messages

    return []

def send_message(agent_id, message):
    """
    Send a message to the agent with the given ID
    """
    payload = {"conversation_id": agent_id, "message": message}
    response = requests.post(API_URL + "/chat-agent", json = payload)
    if response.status_code == 200:
        return response.json()

    return {"response": "Error"}

def main():
    st.set_page_config(page_title = "🤗💬 AIChat")

    with st.sidebar:
        st.title("Conversational Agent Chat")

        # Dropdown to select agent
        agents = get_agents()
        agent_ids = [agent["id"] for agent in agents]
        selected_agent = st.selectbox("Select an Agent:", agent_ids)

        for agent in agents:
            if agent["id"] == selected_agent:
                selected_agent_context = agent["context"]
                selected_agent_first_message = agent["first_message"]

        # Dropdown to select conversation
        conversations = get_conversations(selected_agent)
        conversation_ids = [conversation["id"] for conversation in conversations]
        selected_conversation = st.selectbox("Select a Conversation:", conversation_ids)

        if selected_conversation is None:
            st.write("Please select a conversation from the dropdown.")
        else:
            st.write(f"**Selected Agent**: {selected_agent}")
            st.write(f"**Selected Conversation**: {selected_conversation}")

    # Display chat messages
    st.title("Chat")
    st.write("This is a chat interface for the selected agent and conversation. You can send messages to the agent and see its responses.")
    st.write(f"**Agent Context**: {selected_agent_context}")

    messages = get_messages(selected_conversation)
    with st.chat_message("assistant"):
        st.write(selected_agent_first_message)

    for message in messages:
        with st.chat_message("user"):
            st.write(message["user_message"])
        with st.chat_message("assistant"):
            st.write(message["agent_message"])

    # User-provided prompt
    if prompt := st.chat_input("Send a message:"):
        with st.chat_message("user"):
            st.write(prompt)
        with st.spinner("Thinking..."):
            response = send_message(selected_conversation, prompt)
            with st.chat_message("assistant"):
                st.write(response["response"])

if __name__ == "__main__":
    main()

以下代码通过 API 调用连接到我们的代理微服务,并允许用户选择代理和对话,与代理聊天,类似于 ChatGPT 提供的功能。让我们通过打开另一个终端来运行这个应用程序(确保你的代理微服务在 8000 端口上运行),然后输入 $ streamlit run src/frontend/main.py,你就可以开始了!

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

AI 聊天 Streamlit 应用程序 — 作者提供的图片

未来改进和总结

未来改进

有几个令人兴奋的机会可以通过引入记忆微服务来增强我们的对话代理。这些改进引入了先进的功能,可以延长用户交互的时间,并扩展我们应用程序或整体系统的范围。

  1. 增强的错误处理: 为了确保对话的稳健性和可靠性,我们可以实现代码来优雅地处理意外的用户输入、API 失败——处理 OpenAI 或其他服务的问题,以及在实时交互中可能出现的潜在问题。

  2. 集成缓冲区和对话总结: 由 LangChain 框架实现的缓冲区集成,有可能优化令牌管理,使对话能够在更长的时间内进行而不会遇到令牌限制。此外,集成对话总结可以让用户回顾正在进行的讨论,帮助保持上下文,并改善整体用户体验。请注意代理指令和响应形状,以便在我们的代码中轻松扩展此功能

  3. 数据感知应用: 我们可以通过将我们的代理模型连接到其他数据源,例如内部数据库,来创建具有独特内部知识的代理。这涉及到训练或集成能够理解和响应基于对组织独特数据和信息理解的复杂查询的模型——请查看 LangChain 的数据连接 模块。

  4. 模型多样化: 虽然我们只使用了 OpenAI 的 GPT-3.5 模型,但语言模型提供商的格局正在迅速扩展。测试其他提供商的模型可以进行比较分析,揭示优缺点,并使我们能够选择最适合特定用例的模型——尝试不同的 LLM 集成,例如 HuggingFaceCohereGoogle’s 等。

结论

我们开发了一个微服务,提供由 OpenAI GPT 模型驱动的智能代理,并证明了这些代理可以携带存储在客户端会话之外的记忆。通过采用这种架构,我们解锁了无限的可能性。从上下文感知对话到与复杂语言模型的无缝集成,我们的技术栈已经能够为我们的产品提供新功能。

这种实现及其实际好处表明,使用 AI 的关键在于拥有合适的工具和方法。AI 驱动的代理不仅仅是关于提示工程,还在于我们如何构建工具并更有效地与它们互动,提供个性化体验,并以 AI 和软件工程所能提供的精细和精准处理复杂任务。因此,无论你是在构建客户支持系统、销售虚拟助手、个人厨师还是其他全新事物,请记住,旅程始于一段代码和丰富的想象力——可能性是无限的。

本文的完整代码在 GitHub 上——你可以在 LinkedIn上找到我,欢迎随时联系!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值