【全栈实战】大模型自学:从入门到实战打怪升级,百万字总结(三)

在这里插入图片描述
😊你好,我是小航,一个正在变秃、变强的文艺倾年。
🔔本栏讲解【全栈实战】大模型自学:从入门到实战打怪升级。
🔔专栏持续更新,适合人群:本科生、研究生、大模型爱好者,期待与你一同探索、学习、进步,一起卷起来叭!

智能对话(第一周)

  • 分类:
    • 交互方式
      • 文本对话系统:用户使用文本输入与机器人进行交互,是单一的NLP问题,使用场景包括网页客服机器人、APP客服机器人等在这里插入图片描述
      • 语音对话系统:分为内呼和外呼机器人,用户通过语音直接与机器人进行交流,涉及ASR、NLP及TTS系统,使用场景包括快递/银行官/运营商客服、电话推销、电话回访等在这里插入图片描述
        在这里插入图片描述
        在这里插入图片描述
    • 对话目的
      • 任务型对话:任务型对话,又称多轮对话系统,在预设流程中执行固定对话流程
      • 问答型对话:通过对知识库的检索获取问答结果,例如FAQ系统和KBQA系统
      • 闲聊对话:基于知识库检索或者生成模型进行闲聊式对话
    • 应用场景
      • 多轮对话:多轮对话系统由DM和NLU系统构成,针对不同任务定制对话流程
      • KBQA:基于知识图谱对外提供可靠的问答服务
      • FAQ:基于信息检索对外提供基础问答服务
      • 生成式对话:基于文本生成能力对外提供问答服务

多轮对话

  • 定义:多轮对话指根据上下文内容,进行连续的、以达到解决某一类特定任务为目的的对话。
  • 多轮对话系统参考历史信息的能力,在不同框架下实现方式:
    • 对于传统多轮对话系统,通过用户query驱动预设DM流程的方式实现。
    • 对于LLM,可以通过将前N轮对话内容与当前query组合传入模型的方式实现。多轮对话系统主要用于以下场景:
    • 机器人替代或部分替代人工完成流程明确的沟通过程场景。
    • 由LLM实现的开放域多轮对话场景。

传统多轮对话系统:

  • 核心模块:
    • 对话定制模块,用于根据多轮对话场景定制对话流程,生成对话流程文件。大部分为界面化拖拽方式,由公司自研或者使用第三方DMN工具,也可以以SOP脚本方式实现对话定制。
      • 公司自研DMN系统
        界面化的多轮对话定制系统,采用拖拽和连接的方式构建多轮对话逻辑,提供丰富的节点类型和功能选择,基于公司内部的NLU系统方便的完成节点跳转条件配置等工作
      • 第三方DMN软件
        第三方DMN软件,例如Camunda Modeler等
      • SOP脚本
        采用文本形式组织多轮对话流程,一行为一个节点,每行定义满足条件、禁止条件、跳转节点、播报话术等信息
    • 对话管理模块,即Dialogue Manager,提供对话引擎、对话状态管理等功能,用于驱动定制好的对话流程。
      • DM引擎
        加载多轮对话文件,在内存中建立图结构、树结构或其他形式的流程结构,在新query进入后,驱动流程运行
      • 上下文状态维护与更新
        维护多路对话的状态(session),包括历史轮次信息、各变量中间状态的维护与更新
      • DM工程框架
        DM系统的整体执行逻辑,包括消息的接收处理,对DM系统内部模块的调用和结果信息整合,是DM系统的外部工程框架
    • NLU模块,提供意图识别、槽位提取和NLG功能。
      • 意图解析
        对用户query进行意图解析,通过意图决定DM流程走向,一般采用规则系统与多标签文本分类任务相结合的方式
      • 槽位填充
        对用户query中的槽位信息进行提取,例如,姓名/电话号码/物品名称,一般采用规则系统与NER相结合的方式
      • NLG
        根据意图进入新节点后,返回新节点配置的NLG回复话术

知识图谱

  • 介绍:知识图谱是一种用图模型描述知识建模世界万物之间的关联关系的技术方法。
  • 背景:知识图谱是Google于2012年提出,最早的应用是提升搜索引擎的能力。随后,知识图谱在辅助智能问答、自然语言理解、大数据分析、推荐计算、物联网设备互联、可解释性人工智能等多个方面展现出丰富的应用价值。
  • 组成:知识图谱由节点和边组成
    • 节点可以是实体,如一个人、一本书等,或是抽象的概念,如人工智能、知识图谱等。
    • 边可以是实体的属性,如姓名、书名,或是实体之间的关系,如朋友、配偶。
  • 知识抽取
    • 实体抽取:从文本中检测出命名实体,并将其分类到预定义的类别中,例如人物、组织、地点、时间等
    • 关系抽取:从文本中识别抽取实体及实体之间的关系,即SPO三元组,例如父子关系、上下级关系等
    • 事件抽取:识别文本中关于事件的信息,并以结构化的形式呈现。例如,某事件发生的地点、时间等
  • 知识表示
    • 介绍:知识表示就是对知识的一种描述,或者说是对知识结构的一组约定,一种计算机可以接受的用于描述知识的数据结构。本质上,知识图谱是一种揭示实体之间关系的语义网络,可以对现实世界的事物及其相互关系进行形式化地描述。在这里插入图片描述
    • RDF:图数据库的一种描述方式,或者说是一种使用协议。它以"三元组"(triple)的方式,描述事物与事物之间的直接关系。在这里插入图片描述
      • 三元组:RDF的核心概念,指的是两个事物和它们之间的关系,在语法上呈现为"主语+谓语+宾语",也就是SPO三元组。
      • 要求:谓语(即事物之间的关系)必须有明确定义。如果谓语是给定的,就可以用主语去查询宾语,或者用宾语去查询主语。
      • SPARQL:RDF数据库的查询语言,跟SQL的语法很像。
        • 核心思想:根据给定的谓语动词,从三元组提取符合条件的主语或宾语。
        • 示例:在这里插入图片描述
      • 属性图:属性图由节点集和边集组成。目前被图数据库业界采纳最广的一种图数据模型。
        • 性质:
          • 每个节点具有唯一的id;
          • 每个节点具有若干条出边;
          • 每个节点具有若干条入边;
          • 每个节点具有一组属性,每个属性是一个键值对;
          • 每条边具有唯一的id;
          • 每条边具有一个头节点;
          • 每条边具有一个尾节点;
          • 每条边具有一个标签,表示联系;
          • 每条边具有一组属性,每个属性是一个键值对。
        • 示例:在这里插入图片描述
      • 知识存储:
        • 背景:传统关系数据库无法有效适应知识图谱的图数据模型。
        • 数据库:
          • 负责存储RDF图(RDF graph)数据的三元组库(Triple Store),如DBpedia,应用较少。
            • DBpedia:在这里插入图片描述
          • 管理属性图(Property Graph)的图数据库(Graph Database),如Neo4j、Nebula。【主流】
            • Neo4j:在这里插入图片描述
            • Nebula:在这里插入图片描述
              • 定义:一款开源的分布式图数据库,擅长处理千亿个顶点和万亿条边的超大规模数据集。提供高吞吐量、低延时的读写能力,内置ACL机制和用户鉴权,为用户提供安全的数据库访问方式。
              • 服务:Graph服务、Meta服务和Storage服务,是一种存储与计算分离的架构。Graph服务负责处理计算请求,Storage服务负责数据存储,Meta服务负责数据管理。在这里插入图片描述
                • Storage服务的层次结构:在这里插入图片描述
                  • Storage interface:Storage服务的最上层,定义了一系列和图相关的API
                  • Consensus:Storage服务的中间层,实现了Multi Group Raft,保证强一致性和高可用性;
                  • Store Engine:Storage服务的最底层,是一个单机版本地存储引擎RocksDB(见结尾补充内容)。
              • 实例:一个NebulaGraph实例由一个或多个图空间组成。每个图空间都是物理隔离的,用户可以在同一个实例中使用不同的图空间存储不同的数据集。[https://docs.nebula-graph.com.cn/2.6.1/1.introduction/1.what-is-nebula-graph/]在这里插入图片描述
              • 能力:
                • 高吞吐低时延;
                • 支持线性扩容缩容;
                • 兼容OpenCypher及多种工具;
                • 支持数据备份即快速恢复。
      • 架构:
        • 开放域知识图谱架构:在这里插入图片描述
        • 垂域知识图谱架构:在这里插入图片描述
      • 搭建流程示例:
        • 需求:搭建饮食知识图谱
          • 各类型食物(蔬菜、肉类、水果、坚果)的营养成分(热量、脂肪含量、蛋白质含量、碳水含量);(数仓数据)
          • 各类菜品(素菜、荤菜、荤素搭配、主食)的食材、调料、营养成分、特性(例如低油/低脂/低糖)、烹饪方法、菜系、口味(苦辣酸甜);(无数据)
          • 一日三餐的菜品搭配;(有菜品科学搭配方案)
          • 不同用户需求(减脂/增肌)的菜品搭配;(有菜品科学搭配方案)
          • 食物、食材、调料、菜品、用餐类型之间的明确关系;(无数据)
        • 数据来源:在这里插入图片描述
        • 图谱建模:在这里插入图片描述

RocksDB补充:

  • 介绍:一个高性能、可扩展、嵌入式、持久化、可靠、易用和可定制的键值存储库
  • 数据结构:RocksDB采用LSM树数据结构,支持高吞吐量的写入和快速的范围查询,可被嵌入到应用程序中,实现持久化存储,支持水平扩展,可以在多台服务器上部署,实现集群化存储,具有高度的可靠性和稳定性,易于使用并可以根据需求进行定制和优化。它广泛应用于互联网公司和数据密集型应用中。在这里插入图片描述
    • LSM(Log-Structured Merge Tree):将所有的数据修改操作(如插入、更新、删除)都记录在一个顺序日志文件中,这个日志文件又称为写前日志(Write-Ahead Log,WAL)。顺序日志文件的好处是顺序写入,相比随机写入可以提高写入性能。
      • LSM层级:LSM树中的层级可以分为内存和磁盘两个部分
        • 内存层:内存层也被称为MemTable,是指存储在内存中的数据结构,用于缓存最新写入的数据。当数据写入时,先将其存储到MemTable中,然后再将MemTable中的数据刷写到磁盘中,生成一个新的磁盘文件。由于内存读写速度非常快,因此使用MemTable可以实现高吞吐量的写入操作。
        • 磁盘层:磁盘层是指存储在磁盘中的数据文件,可以分为多个层级。一般来说,LSM树中的磁盘层可以分为Level-0 ~ Level-N几个层级
          • Level-0:Level-0是最底层的磁盘层,存储的是从内存层刷写到磁盘中的文件。Level-0中的文件大小一般比较小,排序方式为按照写入顺序排序。由于数据写入的速度很快,因此Level-0中的文件数量也比较多。
          • Level-1:Level-1是Level-0的上一层,存储的是由多个Level-0文件合并而来的文件。Level-1中的文件大小一般比较大,排序方式为按照键值排序。由于Level-0中的文件数量比较多,因此Level-1中的文件数量也比较多。
          • Level-2及以上:Level-2及以上的磁盘层都是由多个更低层级的文件合并而来的文件,文件大小逐渐增大,排序方式也逐渐趋向于按照键值排序。由于每个层级的文件大小和排序方式不同,因此可以根据查询的需求,选择最适合的层级进行查询,从而提高查询效率。
    • sstable文件:
      • sstable文件由block组成,block也是文件读写的最小逻辑单位,当读取一个很小的key,其实会读取一个block到内存,然后查找数据。
      • 默认的block_size大小为4KB。每个sstable文件都会包含索引的block,用来加快查找。所以block_size越大,index就会越少,也会相应的节省内存和存储空间,降低空间放大率,但是会加剧读放大,因为读取一个key实际需要读取的文件大小随之增加了。
      • 在生产环境使用Nebula数据库时,数据压缩参数、最大读取文件数量和block参数是重要的内存优化参数。

Nebula图数据库搭建补充:

1.下载需要的安装包:

在这里插入图片描述
2.连接服务器,进入工作目录:

# 明确操作系统类型和版本
# Ubuntu
cat /etc/issue
# Ubuntu 20.04.4 LTS \n \l

# CentOS
cat /etc/redhat-release
# CentOS Linux release 7.7.1908 (Core)

# 准备工作目录
cd ~/autodl-tmp/artboy
mkdir nebula && mkdir nebula_info

# 转移数据盘
cd nebula
mv ~/autodl-tmp/nebula/nebula-* .

# ls
# nebula-console-2.6.0  nebula-graph-2.6.0.ubuntu1804.amd64.tar.gz  nebula-graph-studio-3.1.0.x86_64.tar.gz

mv ~/autodl-tmp/nebula/nvm-0.39.3.tar.gz .
# ls
# nebula-console-2.6.0  nebula-graph-studio-3.1.0.x86_64.tar.gz   nebula-graph-2.6.0.ubuntu1804.amd64.tar.gz  nvm-0.39.3.tar.gz

# 安装nebula-graph-2.6.0
tar -xvf nebula-graph-2.6.0.ubuntu1804.amd64.tar.gz
cd nebula-graph-2.6.0.ubuntu1804.amd64
# ls
# bin  etc  logs  pids  scripts  share

3.单机配置:

进⼊安装包的etc配置⽂件⽬录,将三个核⼼组件的配置⽂件nebula-meta.conf.default、nebula-storaged.conf.default、nebulagraphd.conf.default复制为nebula-meta.conf、nebula-storaged.conf、nebula-graphd.conf

# 修改配置
cd etc
mv nebula-graphd.conf.default nebula-graphd.conf &&  mv nebula-storaged.conf.default nebula-storaged.conf && mv nebula-metad.conf.default nebula-metad.conf
# ls
# nebula-graphd.conf     nebula-metad.conf.production     nebula-storaged-listener.conf.production nebula-graphd.conf.production  nebula-storaged.conf nebula-metad.conf    nebula-storaged.conf.production

单机就是执⾏⼀次上⾯的配置⽂件复制,集群就是每台集群复制⼀次

4.日志文件配置:
Nebula执行过程中会产生大量日志信息,如果不进行日志监听级别和存储路径的修改,⽇志会很快占⽤⼤量内存,导致服务器/目录被占满,对nebula-graphd.conf、nebula-metad.conf、nebula-storaged.conf进⾏如下修改:

修改nebula-graphd.conf

# 1.复制路径
pwd
# /root/autodl-tmp/artboy/nebula/nebula-graph-2.6.0.ubuntu1804.amd64/etc

vim nebula-graphd.conf

########## logging ##########
# The directory to host logging files
--log_dir=/root/autodl-tmp/artboy/nebula_info/graph_logs
# Log level, 0, 1, 2, 3 for INFO, WARNING, ERROR, FATAL respectively
--minloglevel=1

修改nebula-metad.conf

vim nebula-metad.conf

########## logging ##########
# The directory to host logging files
--log_dir=/root/autodl-tmp/artboy/nebula_info/meta_logs
# Log level, 0, 1, 2, 3 for INFO, WARNING, ERROR, FATAL respectively
--minloglevel=1

########## storage ##########
# Root data path, here should be only single path for metad
--data_path=/root/autodl-tmp/artboy/nebula_info/meta_data

修改nebula-storaged.conf

vim nebula-storaged.conf

########## logging ##########
# The directory to host logging files
--log_dir=/root/autodl-tmp/artboy/nebula_info/storage_logs
# Log level, 0, 1, 2, 3 for INFO, WARNING, ERROR, FATAL respectively
--minloglevel=1

########## Disk ##########
# Root data path. Split by comma. e.g. --data_path=/disk1/path1/,/disk2/path2/
# One path per Rocksdb instance.
--data_path=/root/autodl-tmp/artboy/nebula_info/storagea_data

Nebula Storage服务的最底层,是⼀个单机版本地存储引擎Rocksdb,可以通过对Rocksdb进⾏参数调整来优化Storage的内存使⽤:

############## rocksdb Options ##############
# rocksdb DBOptions in json, each name and value of option is a string, given as "option_name":"option_value" separated by comma
--rocksdb_db_options={"max_background_jobs":"8", "max_open_files":"50000"}
# rocksdb ColumnFamilyOptions in json, each name and value of option is string, given as "option_name":"option_value" separated by comma
#--rocksdb_column_family_options={"write_buffer_size":"67108864","max_write_buffer_number":"4","max_bytes_for_level base":"268435456"}
# rocksdb BlockBasedTableOptions in json, each name and value of option is string, given as "option_name":"option_value" separated by comma
--rocksdb_block_based_table_options={"block_size":"32768"}

优化效果:

在这里插入图片描述
5.Nebula启动:

进入执行脚本:

cd  ../scripts

# ls
# meta-transfer-tools.sh  nebula-graphd.service  nebula-metad.service  nebula.service  nebula-storaged.service  utils.sh

# 启动服务
./nebula.service start all

[WARN] The maximum files allowed to open might be too few: 1024
[INFO] Starting nebula-metad...
[INFO] Done
[INFO] Starting nebula-graphd...
[INFO] Done
[INFO] Starting nebula-storaged...
[INFO] Done

# 查看状态
./nebula.service status all

[WARN] The maximum files allowed to open might be too few: 1024
[INFO] nebula-metad(3ba41bd): Running as 2034, Listening on 9559
[INFO] nebula-graphd(3ba41bd): Running as 2101, Listening on 9669
[INFO] nebula-storaged(3ba41bd): Exited

一般我们使用lsof来查看端口占用情况

# 1.更新apt-get服务
apt-get update
# 2.安装lsof命令
apt-get install lsof
# 3. 查看端口占用情况
lsof -i:9669
COMMAND    PID USER   FD   TYPE    DEVICE SIZE/OFF NODE NAME
nebula-gr 2101 root  516u  IPv4 263043713      0t0  TCP *:9669 (LISTEN)

安装nebula-graph-studio

tar -xvf nebula-graph-studio-3.1.0.x86_64.tar.gz

启动nebula-graph-studio需要npm命令,autodl服务器默认没有npm命令,需要通过nvm安装node.js

# 1.解压
tar -xvf nvm-0.39.3.tar.gz
# 2.添加配置命令,对~/.bashrc进⾏如下修改:
# ls
# ~/autodl-tmp/artboy/nebula# ls
# nvm-0.39.3.tar.gz nvm-0.39.3

vim ~/.bashrc
# 底部那里添加以下内容
export NVM_DIR="/root/autodl-tmp/artboy/nebula/nvm-0.39.3"
[ -s "$NVM_DIR/nvm.sh" ] && \. "$NVM_DIR/nvm.sh"

# 3.source一下
source ~/.bashrc

# 4.测试nvm命令
nvm -v
# 0.39.3

安装nodejs和nrpm:

nvm install 16.17.0

启动nebula-graph-studio:

cd nebula-graph-studio
# ls
# DEPLOY.md  nebula-graph-studio  nebula-http-gateway

# 1.
cd nebula-http-gateway
nohup ./nebula-httpd &
# [1] 4241

# 2.
cd ../nebula-graph-studio
npm run start
# lsof -i:7001
# COMMAND  PID USER   FD   TYPE    DEVICE SIZE/OFF NODE NAME
# node    4288 root   23u  IPv4 263509160      0t0  TCP *:7001 (LISTEN)

6.本地访问:

1.点击自定义服务,选择Linux/Mac

在这里插入图片描述
粘贴,修改需要映射的端口号:

在这里插入图片描述
访问本地端口:

在这里插入图片描述
输入账号密码进行登录:

在这里插入图片描述

KBQA系统

  • 介绍:一种根据知识库知识,准确、简洁地回答自然语言问题的问答系统。
  • 实现方式:
    • 基于语义解析(Semantic Parsing)的方法:对问句进行句法/语法解析和信息提取,并将解析结果组合成可执行的逻辑表达式(如SPARQL),直接从图数据库中查询答案。
      • 步骤:将自然语言的问句解析成逻辑形式(Logic Form)
        • (1)问题解析:解析自然语言问题句法&语义(Pythia);在这里插入图片描述
        • (2)模板生成:通过规则系统将语义信息映射为SPARQL模板;在这里插入图片描述
        • (3)模板实例化:通过实体链接和关系链接将SPARQL模板中的slot进行填充得到完整的可查询的模板;
          • 实体词链接:SPARQL模板要想作用于一个RDF数据库,需要将模板中的字符串映射为RDF数据库中的实体词(比如:类型词、实例词和属性词等)
            • 类型词与实体词链接:对于待匹配的字符串s,从WordNet中获取近义词词典S(s),找到符合查找类型label(e)的所有实体e,利用字符串相似度计算实体e与每个近义词的相似度,选出字符串相似度最高的实体。在这里插入图片描述
              • trigram:三元分词,把句子从头到尾每三个字组成一个词语。
              • 编辑距离:两个字串之间,由一个转换成另一个所需的最少编辑操作次数。
              • 最大子串相似度:两个字符串之间最长的相同子字符串的长度。
            • 属性/关系词链接:由于属性词(谓词)可以有多重说法,所以属性词的链接相对复杂,这里采用BOA(Bootstrapping linked datA)框架通过bootstrap的方法挖掘出不同的自然语言说法到谓词的映射
              • (1)对于每一个谓词,通过知识库K我们可以得到很多满足I ( p ) = { ( x, y ) : ( x p y )∈K }的样例;
              • (2)对于每个样例{ x y },我们可以从语料库(如wiki等)找出x和y共现的句子;
              • (3)根据label(x) .* label(y)label(y) .* label(x)的正则从共现句中匹配出子串;
              • (4)对于这些子句,BOA会抽象出NLE表达式θ,形如?D?representation?R?或者?R?representation?D?,其中?D?和?R?为label(x)和label(y)的占位符,例如,?book? is abook of ?author?。NLE抽取过程会得到一个庞大集合(p, θ),称之为BOA patterns,每一个θ都代表p的一种潜在表示;
              • (5)对每一个挖掘出来的表达式θ,我们可以计算出其与某个谓词p的匹配得分,根据这个得分我们可以筛选出最合适的映射关系。
              • 评估指标:
                • 一个好的NLE θ能够覆盖 I ( p ) 中的多个元素,即θ的支持性support;(类似TF)【大白话:吃,覆盖多少句子合理(老虎吃肉,小明吃汉堡,…)】在这里插入图片描述
                • 一个好的 NLE θ 的占位符 ?D? 和 ?R? 应该只匹配 rdf:type 在 p 的 range 和 domain 限制范围内的实体 label,即 θ 的典型性 typicity; (主-宾类型适配度)【大白话:吃(小明吃汉堡、小明吃苹果);喜欢(小明喜欢苹果,小明喜欢电影);吃和喜欢跟的宾语类型不一样,相当于类型加了一层限制】在这里插入图片描述
                • 一个好的NLE θ应该专门用来表达p,也就是说它只能表征少数的p,即θ的特异性specificity。(类似IDF)【大白话:小明喜欢xx,小明害怕xx,喜欢和害怕能好多替换,专一性越弱,分子分母越接近,log越小在这里插入图片描述
          • SPARQL排序:通过模版实例化过程后,得到一批候选SPARQL,接下来需要排序并选择最优SPARQL
            • 实现方法:对SPARQL中填入的每一个链接词,计算如下的分数score(e)(链接词的相似性分数σ(e)与三元组的显著性分数φ(e)):【大白话:σ(e)的含义是我喜欢你的关系整个知识图谱出现过多少次;score:我和你是妻子。我在图谱的贡献 + 妻子在图谱的贡献在这里插入图片描述
        • (4)模板排序:因为自然语言的模糊性,一句话可能映射为多个SPARQL模板,所以会对多个模板进行排序;
        • (5)模板查询:用SPARQL模板从RDF数据查询获取结果。
      • 语义解析补充:
        • 论文的实验是在QALD5数据集(QALD5基准:包含两组50个关于DB-pedia的问题)上验证,通过SPARQL查询和答案注释,每个问题都采用准确和召回进行评估。
          • 评估结果:平均precision为0.61,平均recall为0.63,由此计算F值为0.62。
          • 核心问题:解决两种描述语言之间的不匹配问题,一种是数据库中的干净、规范化的本体论描述语言,另一种是自然语言中获取的查询描述语言,如何将这两种描述语言匹配起来,是KBQA的难点。
        • AMR:大模型出现之前,比较有潜力的方法,该方法旨在使用问题中传达的信息,直接从知识库中检索并排序答案在这里插入图片描述
    • 基于信息检索(Information Retrieval)的方法:先解析出问句的主实体,再从KG中查询出主实体关联的多个三元组,组成子图路径(也称多跳子图),之后分别对问句和子图路径编码、排序,返回分数最高的路径作为答案。在这里插入图片描述
      • (1)从问题中确定中心实体,并从知识库中提取出特定于问题的子图,理想情况下,该图应该包含所有与中心实体相关的实体和关系;
      • (2)通过一个问题表示模块,对输入的问题进行embedding,得到编码向量;
      • (3)通过候选答案表示模块,对候选子图进行embedding;
      • (4)对问题embedding和候选子图embedding进行相似度计算,选出目标答案。
      • 基于Bert:
        • 方法:在这里插入图片描述
          • 1.首先找到实体链接系统中连接主题实体e_topic和候选实体ei的所有路径(设置最大路径数并在数量超过阈值时应用下采样);
          • 2.然后通过在知识库KB中用实体名称替换节点和用关系名称替换边来构建每条路径的文本形式
          • 3.然后concatenate问题q和所有路径p1, …,pn生成输入样本:xi = [CLS]q[SEP]p1 [SEP]……pn[SEP];
          • 4.将样本提供给BERT并采用与[CLS] token对应的表示进行二分类(将这些路径视为主题实体e_topic和候选实体ei之间的事实,目标是使用BERT来预测假设“ei is the answer ofq”是否得到这些知识库KB事实的支持)。
        • 预训练任务:在这里插入图片描述
          • 1.Relation Extraction(RE):从句子中推断关系,基于大规模关系抽取开源数据集,生成了大量一跳([CLS]s[SEP]h, r, t[SEP])与两跳([CLS]s1 , s2 [SEP]h1 , r1 , t1 (h2 ), r2 ,t2 [SEP])的文本对训练数据,让模型学习自然语言与结构化文本间的关系。
          • 2.Relation Matching(RM):判断两个句子是否表达相同关系,为了让模型更好的捕捉到关系语义,我们基于关系抽取数据生成了大量文本对,拥有相同关系的文本互为正例,否则为负例。
          • 3.Relation Reasoning(RR):自监督方式从知识库构建数据,对缺失连接进行推理。为了让模型具备一定的知识推理能力,假设图谱中的(h, r, t)缺失,并利用其他间接关系来推理(h, r, t)是否成立,输入格式为:[CLS]h, r, t[SEP]p1 [SEP] . . . pn [SEP]。

FAQ

  • 介绍:FAQ系统基于问答库,采用文本匹配的方式召回候选问答对,排序后进行问答回复,提供一问一答式的问答体验。
  • 流程:
    • 知识库构建:通过内部数据及外部采集数据搭建QA知识库,后续FAQ问答基于知识库内容。在这里插入图片描述在这里插入图片描述
    • 召回策略实现:实现从FAQ库内召回候选问答簇的算法策略。
    • 匹配策略实现:实现对召回问答簇进行文本匹配的精排策略,用于确定候选的TopK个答案。
    • FAQ系统搭建:完成FAQ整体系统的搭建,串联FAQ系统中的各个模块。
  • 架构:
    • 1.0:在这里插入图片描述
    • 2.0:在这里插入图片描述
    • 3.0:在这里插入图片描述

生成式对话系统

在这里插入图片描述

实战(第二、三周)

Ollama

介绍

Ollama是一个专为在本地环境中运行和定制大型语言模型而设计的工具。它提供了一个简单而高效的接口,用于创建、运行和管理这些模型,同时还提供了一个丰富的预构建模型库,可以轻松集成到各种应用程序中。Ollama的目标是使大型语言模型的部署和交互变得简单,无论是对于开发者还是对于终端用户。

安装

Win 安装

官网下载:https://ollama.com/

验证安装:

ollama --version

在这里插入图片描述
windows 的安装默认不支持修改程序安装目录,

  • 默认安装后的目录:C:\Users\username\AppData\Local\Programs\Ollama
  • 默认安装的模型目录:C:\Users\username\ .ollama
  • 默认的配置文件目录:C:\Users\username\AppData\Local\Ollama

我们可以设置环境变量来调整保存model的地址:

在这里插入图片描述
在这里插入图片描述
新建系统变量(别忘了提前创建文件夹):

在这里插入图片描述

Linux 安装
# 方式一:
# 更新包列表
sudo apt-get update
# 安装ollama
sudo apt-get install ollama

# 方式二:
curl -fsSL https://ollama.com/install.sh | sh

# 验证安装
ollama --version
Docker 安装
docker pull ollama/ollama
docker run -d -p 3000:8080 --gpus=all -v ollama:/root/.ollama -v open-webui:/app/backend/data --name open-webui --restart always ollama/ollama

# 验证
curl localhost:3000

常用命令

1. 启动Ollama服务
ollama serve

2. 从模型文件创建模型
ollama create

3. 显示模型信息
ollama show

4. 运行模型
ollama run 模型名称

5. 从注册表中拉去模型
ollama pull 模型名称

6. 将模型推送到注册表
ollama push

7. 列出模型
ollama list

8. 复制模型
ollama cp

9. 删除模型
ollama rm 模型名称

10. 获取有关Ollama任何命令的帮助信息
ollama help

Agent

架构

在这里插入图片描述
FastChat是一个用于训练、服务和评估基于大型语言模型的聊天机器人的开放平台。

  • 提供SOTA模型的训练和评估代码;
  • 提供分布式多模型部署框架+ WebUI + OpenAI API;

FastChat可以帮助我们快速的进行大模型部署,并对外提供服务,可以直接复用OpenAI的API对本地部署模型进行访问。

环境准备

显卡:A40;GPU memory:48GB;
版本镜像:pytorch 1.11.0  python3.8  Ubuntu

代码准备

代码仓库:🔗Qwen-14B-Chat-Demo

模型:Qwen-14B-Chat

模型启动

模型下载:

conda create -n modelscope_env python=3.8.5
source activate modelscope_env
pip install modelscope
# cat qwen_14B_chat_download.py
python qwen_14B_chat_download.py

FastChat:

conda create -n fastchat_env python=3.8.5
source activate fastchat_env
# 安装 FastChat 包
bash fastchat_install.sh
# 安装 Qwen 依赖
bash requirements_install.sh
# 启动 Controller, Controller 启动后会占⽤21001端⼝
nohup bash controller_start.sh > cl_20240808.log &
# 启动 Worker 
# cat worker_start.sh
nohup bash worker_start.sh > wk_20240808.log &
# 启动 OpenAI 接⼝⽀持, OpenAI 接⼝⽀持服务启动后会占⽤8000端⼝
nohup bash openai_support.sh > oa_20240808.log &
# 本地机器建⽴端⼝映射
ssh -CNg -L 8000:127.0.0.1:8000 root@region-42.seetacloud.com -p 26980

Agent:

在这里插入图片描述

python agent_chat.py

训练框架

在这里插入图片描述

  • 通用问答能力:系统能够理解和回答各种广泛主题下的问题的能力。这种系统不局限于某个特定领域或行业,而是能够跨越多个领域,如历史、科学、技术、娱乐、体育等,提供准确和有用的回答。
  • 垂域问答能力:专注于某个特定领域或行业内的问答任务。这种系统针对某一领域的专业知识进行深度学习和优化,能够提供更加精确和专业的回答。垂域问答系统通常被应用于医疗、法律、金融、教育等需要高度专业知识的领域。

QLoar微调

环境准备

安装教程请移步大模型自学:从入门到实战打怪升级(一)

conda activate pytorch

python
>>> import torch
>>> torch.cuda.is_available()
>>> True

>>> print(torch.__version__)  #注意是双下划线
>>> 2.4.0
>>> exit()

代码准备

代码仓库:🔗qwen_qlora

在这里插入图片描述
上传代码至阿里云盘

模型来源

在这里插入图片描述
qwen_download.py

from modelscope.hub.snapshot_download import snapshot_download
# 使用Library Hub下载模型
model_dir = snapshot_download('qwen/Qwen-14B-Chat', cache_dir='/root/autodl-tmp/artboy/base_model/', revision='v1.0.8')

模型微调

🔗AutoDL:https://www.autodl.com/home

服务器选择:

显卡:A40;GPU memory:48GB;
版本镜像:pytorch 1.11.0  python3.8  Ubuntu

FP32:32Bits = 4Bytes;FP16:16Bits = 2Bytes
BatchSize = 1;SequenceLen = 1024;HiddenSize = 5120;AttentionHead = 40;Layer = 40;

模型参数大小:14B
- Model Weight:28GB
- Gradient:28GB(即 Model Weight)
- Optimizer State:168GB(即 Model Weight * 2 * 3)
	- 梯度指数平滑值:56GB(即 Model Weight * 2)
	- 梯度平方指数平滑值:56GB(即 Model Weight * 2)
	- 模型参数:56GB(即 Model Weight * 2)
- Activation:14.5GB【BatchSize * SequenceLen * HiddenSize * Layer * (34 + 5 * AttentionHead / HiddenSize)】
全参微调至少需要238.5GB;

模型参数大小:14B
- Model Weight:28GB
- Adapter weight:0.7GB(即 Origin Adapter weight * 2.5%)
- Gradient:0.7GB(即 Origin Gradient * 2.5%)
- Optimizer State:4.2GB(即 Origin Optimizer State * 2.5%)
- Activation:0.36GB(即 Origin Activation * 2.5%)
Loar微调至少需要47.96GB;

模型参数大小:14B
- Model Weight:7GB(4bit量化)
- Double Quantization:3.6GB(即 Origin Model Weight * 0.127)
- Adapter weight:0.7GB(即 Origin Adapter weight * 2.5%)
- Gradient:0.7GB(即 Origin Gradient * 2.5%)
- Optimizer State:4.2GB(即 Origin Optimizer State * 2.5%)
- Activation:0.36GB(即 Origin Activation * 2.5%)
QLoar微调至少需要16.56GB;

创建基本目录并移动代码:

mkdir -p /root/autodl-tmp/artboy /root/autodl-tmp/tmp
cd /root/autodl-tmp/artboy
mkdir -p finetune data base_model

cd finetune
mv ~/autodl-tmp/qwen_qloar/ .
ls

下载模型:

conda create -n modelscope_env python=3.8
source activate modelscope_env 
pip install modelscope
python qwen_download.py

安装依赖:

conda create -n qwen_env python=3.8
source activate qwen_env
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
# 也可以编写一个安装脚本
# bash install_req.sh

修改训练参数:

如果是单卡训练:train_qwen_qlora.shCUDA_VISIBLE_DEVICES=0

CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 train_qlora.py --train_args_file config/qwen-14b-qlora.json

修改模型路径:config/qwen-14b-qlora.json中model_name_or_path

{
    "output_dir": "trained_models/Qwen-14B-Chat-Keywords-1118-match-1124",
    "model_name_or_path": "/root/autodl-tmp/artboy/base_model/Qwen-14B-Chat",
    "train_file": "./data/text_matching_data_train.jsonl",
    "num_train_epochs": 1,
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 4,
    "learning_rate": 5e-5,
    "max_seq_length": 1024,
    "logging_steps": 10,
    "save_steps": 500,
    "save_total_limit": 1,
    "lr_scheduler_type": "constant_with_warmup",
    "warmup_steps": 300,
    "lora_rank": 64,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "gradient_checkpointing": true,
    "disable_tqdm": false,
    "optim": "paged_adamw_32bit",
    "seed": 42,
    "fp16": true,
    "report_to": "tensorboard",
    "dataloader_num_workers": 0,
    "save_strategy": "steps",
    "weight_decay": 0,
    "max_grad_norm": 0.3,
    "remove_unused_columns": false
}

启动:【日志命名规则:日期_时间.log

nohup bash train_qwen_qlora.sh > 202400806_1153.log &
tail -f 202400806_1153.log

开启新的窗口,关注显卡使用情况:【动态调整batch_size来合理利用GPU资源】

watch -n -1 nvidia-smi

模型评估

这里我们跑500个case测测准确率【正常应该在90%以上】,文件在data/text_matching_data_test_result.csv

tail data/text_matching_data_test_result.csv

句子1:今日也是爱你的一天。句子2:今天一天都是爱着你的歌词。判断这两个句子的意思是否相同:	不相同	不相同
判断下面两个句子是否表达了相同的语义:。文本1:马上关机马上关机。文本2:马上去了	答案:不相同	答案:不相同
句子1:单位怎么合法的办理给员工调动岗位和工资呢?。句子2:不接受公司调岗的劳动争议应该怎么办。判断这两个句子的意思是否相同:	不相同	不相同
下面两个句子是否表达了相同的意思:。文本1:怎么在淘宝上投诉卖家。文本2:如何投诉淘宝卖家?。答案:	相同	相同
下面两句话的意思是否相同:。文本1:这是什么时候的旗帜?。文本2:这是什么旗帜。选项:相似,不相似。答案:	不相同	不相同
下面两个句子表达的意思相同吗:。句子1:保证人可以以自己没钱,不履行担保责任么?。句子2:借款人没有钱还担保人也还不上担保人结果怎么样。选项:相似,不相似。答案:	不相同	不相同
文本1:我因为工作经常弯腰得了腰间盘突出的病,这属于工伤吗?。文本2:腰间盘突出可以作为职业病的鉴定吗?单位该如何赔偿?。这两个句子是否表达了相同的意思:	不相同	不相同
下面两个句子表达的意思相同吗?。文本1:如何使用level2软件。文本2:l2购买过以后在哪里打开使用。答案:	相同	相同
下面两个句子表达的意思相同吗:。句子1:如何查看自己购买的基金。句子2:我想查看自己的帐号怎么看。选项:相似,不相似。答案:	不相同	不相同
文本1:阿莫西林哺乳期可以吃吗。文本2:哺乳期能吃阿莫西林么?。这两个句子表达了相同的语义吗:	相同	相同
下面句子是否表示了相同的语义:。文本1:公民、法人对行政机关作出的行政处罚不服的,是否可以申请行政复议?。文本2:对行政处罚不服的公民、法人,能否申请行政复议?。选项:相似,不相似。答案:	相同	相同

执行评估脚本:

python model_evaluation.py

Deepspeed微调

环境准备

安装教程请移步大模型自学:从入门到实战打怪升级(一)

conda activate pytorch

python
>>> import torch
>>> torch.cuda.is_available()
>>> True

>>> print(torch.__version__)  #注意是双下划线
>>> 2.4.0
>>> exit()

代码准备

代码仓库:🔗deepspeed_baichuan2_7B_base

在这里插入图片描述

模型来源

在这里插入图片描述
llm_download/baichuan2_7B_base_download.py

from modelscope.hub.snapshot_download import snapshot_download
# 使用Library Hub下载模型
model_dir = snapshot_download('baichuan-inc/Baichuan2-7B-Base', cache_dir='/hy-tmp/autodl-tmp/artboy/base_model/', revision='v1.0.2')

模型微调

🔗恒源云:https://gpushare.com

服务器选择:

显卡:A800 * 2;GPU memory:80GB * 2;
版本镜像:pytorch 1.11.0  python3.8  Ubuntu

FP32:32Bits = 4Bytes;FP16:16Bits = 2Bytes
BatchSize = 1;SequenceLen = 1024;HiddenSize = 5120;AttentionHead = 40;Layer = 40;

模型参数大小:7B
- Model Weight:14GB
- Gradient:14GB(即 Model Weight)
- Optimizer State:84GB(即 Model Weight * 2 * 3)
	- 梯度指数平滑值:28GB(即 Model Weight * 2)
	- 梯度平方指数平滑值:28GB(即 Model Weight * 2)
	- 模型参数:28GB(即 Model Weight * 2)
- Activation:14.5GB【BatchSize * SequenceLen * HiddenSize * Layer * (34 + 5 * AttentionHead / HiddenSize)】
全参微调至少需要126.5GB;

ZeRO-1【对优化器存储进行切分】
模型参数大小:7B
- Model Weight:14GB
- Gradient:14GB
- Optimizer State:84GB / 4 = 21GB【GPU数量:4】
- Activation:14.5GB
全参微调至少需要63.5GB;

ZeRO-2【对优化器存储、梯度存储数据进行切分】
模型参数大小:7B
- Model Weight:14GB
- Gradient:14GB / 4 = 3.5GB【GPU数量:4】
- Optimizer State:84GB / 4 = 21GB【GPU数量:4】
- Activation:14.5GB
全参微调至少需要53GB;

ZeRO-3【对优化器存储、梯度存储数据、参数w存储数据进行切分】
模型参数大小:7B
- Model Weight:14GB / 4 = 3.5GB【GPU数量:4】
- Gradient:14GB / 4 = 3.5GB【GPU数量:4】
- Optimizer State:84GB / 4 = 21GB【GPU数量:4】
- Activation:14.5GB
全参微调至少需要42.5GB;

价格预算:
两卡 - batch_size:4 - 只能stage3  - 加offload - 20个小时 = 400元
两卡 - batch_size:4 - 只能stage3  - 不加offload - 14个小时 = 280元
四卡 - batch_size:4 - 只能stage3 - 加offload - 6.5个小时 = 260元

建议先使用 7B - Chat版本跑通流程

创建基本目录并移动代码【参考上一个教程】
下载模型【参考上一个教程】
安装依赖【参考上一个教程】
修改训练参数:training_config/baichuan2_config.json.json

{
    "output_dir": "./output/baichuan2-sft-1e5-1125",
    "model_name_or_path": "/hy-tmp/autodl-tmp/artboy/base_model/Baichuan2-7B-Base",
    "deepspeed": "/hy-tmp/autodl-tmp/artboy/finetune/llm_code/deepspeed_config/deepspeed_stage_1_config.json",
    "train_file": "/hy-tmp/autodl-tmp/artboy/finetune/llm_code/data/psychology_data.jsonl",
    "num_train_epochs": 2,
    "per_device_train_batch_size": 12,
    "gradient_accumulation_steps": 4,
    "learning_rate": 1e-5,
    "max_seq_length": 512,
    "logging_steps": 10,
    "save_steps": 100,
    "save_total_limit": 1,
    "lr_scheduler_type": "cosine",
    "warmup_steps": 200,
    "gradient_checkpointing": false,
    "disable_tqdm": false,
    "optim": "adamw_hf",
    "seed": 42,
    "fp16": true,
    "report_to": "tensorboard",
    "dataloader_num_workers": 5,
    "save_strategy": "steps",
    "weight_decay": 0,
    "max_grad_norm": 1.0,
    "remove_unused_columns": false
}

动态调整batch_size:

在这里插入图片描述
启动:【日志命名规则:日期_时间.log

nohup bash train_qwen_qlora.sh > 202400806_1153.log &
tail -f 202400806_1153.log

开启新的窗口,关注显卡使用情况:【动态调整batch_size来合理利用GPU资源】

watch -n -1 nvidia-smi

模型评估

大模型评估主要用于两个方面:

  • 在训练过程中对大模型进行评估,主要用于判断当前训练阶段模型的效果,避免出现由于参数或者数据原因造成的训练过程中模型坍塌的情况,在间隔N个checkpoint时对模型效果进行评估。
  • 在训练完成后对大模型进行评估,主要用于评估训练完成的模型与 baseline 模型的性能差异,以及训练完成的模型与其他大模型的性能差异,用于明确大模型的垂域生成能力和通用能力

我们本次使用 BLEU和MMCU对训练完成的大模型进行评估。

BLEU
  • 介绍:BLEU由BM于2002年提出,用于评估机器翻译任务,原文《BLEU:a Method forAutomatic Evaluation of Machine Translation》,发布于ACL,引用次数 10000+。
  • 核心思想:衡量自动生成的翻译结果 (candidate) 与参考翻译 (reference) 之间的相似度。
  • 指标计算方法:基于n-gram的匹配
    • 统计n-gram: 首先,我们将候选翻译和参考翻译都分割成n个连续的单词。例如,当n取2时,一句话"我爱你"会被分割成"我爱”“爱你”。然后,我们统计每个h-gram在候选翻译中的出现次数以及在参考翻译中的最大出现次数。
    • 精确匹配率 (Precision): 对于每个n-gram,我们计算它在候选翻译中的出现次数与在参考翻译中的最大出现次数的比例。如果候选翻译中的出现次数大于等于参考翻译中的最大出现次数,那么该n-gram的精确匹配率为1,否则为候选翻译中的出现次数除以参考翻译中的最大出现次数。
    • 累权重准确率(Cumulative n-gram Weighted Precision):BLEU采用不同n-gram长度的加权计算。具体来说,对于每个n-gram的精确匹配率,我们将其乘以相应的权重,然后将这些加权值进行累加,得到累加权重准确率。
    • Brevity Penalty (BP):为了惩罚过短的候选翻译并鼓励生成和参考翻译长度相似的结果,我们计算候选翻译长度与参考翻译长度的比例。如果候选翻译长度小于参考翻译长度,则BP为1,否则,BP为e(1 参考翻译长度/候选翻译长度)。这个惩罚项被乘以之前计算的累加权重准确率越长BP越大)
    • BLEU分数:最终,将累加权重准确率乘以Brevity Penalty得到BLEU分数。通常,BLEU分数会乘以100,以便更方便地表示百分比形式。
MMCU
  • 介绍:2023年5月15日,甲骨易AI研究院推出首个中文的大模型评测数据集——“超越”(Massive Multitask Chinese Understanding,简称MMCU),填补了中文大语言模型能力测试缺失的一大空白。
  • MMCU测试集包括医疗、法律、心理学和教育四个领域的测试数据,数据被组织成选择题的形式,可以用于评估模型的通用能力,或者以MMCU作为基础测试集,在此基础上进行进一步的拓展。
  • 地址:https://github.com/Felixgithub2017/MMCU

服务展示

WSGI

介绍:WSGI (Web Server Gateway lnterface) 是一种定义了Web应用程序和Web服务器之间通信协议的Python标准。WSGI 服务器是实现了这个协议的服务器,用于处理来自客户端的HTTP请求并将请求传递给相应的Web应用程序进行处理。

Gunicorn

在这里插入图片描述

  • 介绍:GunicornAGreen Unicorn) 是一个基于Python的Web服务器网关接口 (WSGI) HTTP服务器。它被广泛用于将Python Web应用程序部署到生产环境中。
  • 设计目标:提供简单、稳定和高效的服务。它支持多进程,能够处理并发请求并通过监控和管理工作进程来提高应用程序的性能和可靠性。
    • 多种工作模式: Gunicorn支持不同的工作模式,包括同步 (Sync))、异步 (Async)和线程池(Thread-based)。您可以根据实际需求选择适合的工作模式,以平衡性能和资源消耗。
    • 配置灵活性:Gunicorn提供了一系列的配置选项,可以通过配置文件或命令行参数进行设置您可以自定义监听端口、工作进程数、超时时间、日志级别等各种参数,以满足特定应用程序的需求。
    • 负载均衡支持: Gunicorn允许结合其他工具或服务器 (如Nginx或HAProxy)实现负载衡您可以将多个Gunicorn实例放在负载均衡器后面,以提高系统的可扩展性和容错性。
Bottle
  • 介绍:Bottle 是一个简单而轻量级的 Python Web 框架,专注于快速、简洁地构建Web应用程序和API服务。
  • 特点:
    • 简单易用: Bottle 的设计理念是使Web开发变得简单,它的API 非常直观和简洁。您可以仅使用几行代码就能创建女个功能完备的web 应用程序
    • 轻量级:Bottle 的核心只有一个Python文件,没有外部依赖。这使得它非常轻量目易于部署和维护。您可以将 Bottle 应用程序部署在各种环境中,包括服务器、云平台和单个脚本文件等。
    • 路由和请求处理: Botle 提供了灵活且易于使用的路由系统。您可以使用装饰器来定义路由规则,并根据请求方法(GET、POST等)和路径匹配进行处理。此外,Bottle 还提供了方请求与响应对象,使得处理请求数据和生成响应变得简单。

RAG

介绍

  • 背景:大型语言模型(LLM)存在一些固有的局限性,如“模型幻觉问题”、“时效性问题”和“数据安全问题”。
  • 介绍:RAG 是检索增强生成(Retrieval Augmented Generation )的简称,它为大语言模型 (LLMs) 提供了从数据源检索信息的能力,并以此为基础生成回答。
    在这里插入图片描述
    • 步骤1:问题理解,准确把握用户的意图。
    • 步骤2:知识检索,从知识库中相关的知识检索。【难点,用户提问可能以多种方式表达,而知识库的信息来源可能是多样的,包括PDF、PPT、Neo4j等格式。
    • 步骤3:答案生成,将检索结果与问题。
  • 优点:
    • 提高准确性和相关性。
    • 改善时效性,使模型适应当前事件和知识。
    • 降低生成错误风险,依赖检索系统提供的准确信息。

以下是RAG输出到大型语言模型的典型模板:

你是一个{task}方面的专家,请结合给定的资料,并回答最终的问题。请如实回答,如果问题在资料中找不到答案,请回答不知道。

问题:{question}

资料:
- {information1}
- {information2}
- {information3}

其中,{task}代表任务的领域或主题,{question}是最终要回答的问题,而{information1}、{information2}等则是提供给模型的外部知识库中的具体信息。

分类

参考论文:Retrieval-Augmented Generation for Large Language Models: A Survey

RAG可以根据技术复杂度,分为三种:

  • Naive RAG:Naive RAG是RAG技术的最基本形式,也被称为经典RAG。包括索引、检索、生成三个基本步骤。索引阶段将文档库分割成短的Chunk,并构建向量索引。检索阶段根据问题和Chunks的相似度检索相关文档片段。生成阶段以检索到的上下文为条件,生成问题的回答。
  • Advanced RAG:Advanced RAG在Naive RAG的基础上进行优化和增强。包含额外处理步骤,分别在数据索引、检索前和检索后进行。包括更精细的数据清洗、设计文档结构和添加元数据,以提升文本一致性、准确性和检索效率。在检索前使用问题的重写、路由和扩充等方式对齐问题和文档块之间的语义差异在检索后通过重排序避免“Lost in the Middle”现象,或通过上下文筛选与压缩缩短窗口长度在这里插入图片描述
  • Modular RAG:Modular RAG引入更多具体功能模块,例如查询搜索引擎、融合多个回答等。技术上融合了检索与微调、强化学习等。流程上对RAG模块进行设计和编排,出现多种不同RAG模式。提供更大灵活性,系统可以根据应用需求选择合适的功能模块组合。模块化RAG的引入使得系统更自由、灵活,适应不同场景和需求。

RAG和SFT对比:

特性RAG技术SFT模型微调
知识更新实时更新检索库,适合动态数据,无需频繁重训存储静态信息,更新知识需要重新训练
外部知识高效利用外部资源,适合各类数据库可对齐外部知识,但对动态数据源不够灵活
数据处理数据处理需求低需构建高质量数据集,数据限制可能影响性能
模型定制化专注于信息检索和整合,定制化程度低可定制行为,风格及领域知识
可解释性答案可追溯,解释性高解释性相对低
计算资源需要支持检索的计算资源,维护外部数据源需要训练数据集和微调资源
延迟要求数据检索可能增加延迟微调后的模型反应更快
减少幻觉基于实际数据,幻觉减少通过特定域训练可减少幻觉,但仍然有限
道德和隐私处理外部文本数据时需要考虑隐私和道德问题训练数据的敏感内容可能引发隐私问题

实战 - 搭建 RAG Demo

预备

1.开通免费试用阿里云PAI—DSW

链接:https://free.aliyun.com/?searchKey=PAI

开通PAI-DSW 试用 ,可获得 5000算力时!有效期3个月!

在这里插入图片描述
2.在魔搭社区进行授权

链接:https://www.modelscope.cn/my/mynotebook/authorization

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
创建实例:在这里插入图片描述
这里一定要选择支持资源包抵扣的服务器
在这里插入图片描述
创建实例后打开,能看到这个界面就成功啦!

在这里插入图片描述

索引

新建bge-small-zh-v1.5-download.py,用于向量模型下载

# 向量模型下载
from modelscope import snapshot_download
model_dir = snapshot_download("AI-ModelScope/bge-small-zh-v1.5", cache_dir='.')
/usr/local/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Downloading: 100%|██████████| 190/190 [00:00<00:00, 328B/s]
Downloading: 100%|██████████| 776/776 [00:00<00:00, 1.48kB/s]
Downloading: 100%|██████████| 124/124 [00:00<00:00, 266B/s]
Downloading: 100%|██████████| 47.0/47.0 [00:00<00:00, 92.2B/s]
Downloading: 100%|██████████| 91.4M/91.4M [00:00<00:00, 120MB/s] 
Downloading: 100%|██████████| 349/349 [00:00<00:00, 715B/s]
Downloading: 100%|██████████| 91.4M/91.4M [00:00<00:00, 123MB/s] 
Downloading: 100%|██████████| 27.5k/27.5k [00:00<00:00, 54.9kB/s]
Downloading: 100%|██████████| 52.0/52.0 [00:00<00:00, 65.0B/s]
Downloading: 100%|██████████| 125/125 [00:00<00:00, 253B/s]
Downloading: 100%|██████████| 429k/429k [00:00<00:00, 703kB/s]
Downloading: 100%|██████████| 367/367 [00:00<00:00, 759B/s]
Downloading: 100%|██████████| 107k/107k [00:00<00:00, 220kB/s]

封装向量模型类 EmbeddingModel:

from typing import List
from transformers import AutoTokenizer, AutoModel
import torch

# 定义向量模型类
class EmbeddingModel:
    """
    用于加载预训练的模型并计算文本的嵌入向量的类。
    """

    def __init__(self, path: str) -> None:
        """
        初始化方法,加载预训练的分词器和模型。

        :param path: 预训练模型的路径。
        """
        # 加载预训练的分词器
        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # 加载预训练的模型,并将其移动到GPU上
        self.model = AutoModel.from_pretrained(path).cuda()
        print(f'Loading EmbeddingModel from {path}.')

	# 为了充分发挥GPU矩阵计算的优势,输入和输出都是一个 List,即多条文本和他们的向量表示。
    def get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        计算文本列表的嵌入向量。

        :param texts: 要计算嵌入向量的文本列表。
        :return: 嵌入向量的列表。
        """
        # 使用分词器对文本进行编码
        encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        
        # 将编码后的输入数据移动到GPU上
        encoded_input = {k: v.cuda() for k, v in encoded_input.items()}

        # 不计算梯度,以提高计算效率
        with torch.no_grad():
            # 将编码后的输入传递给模型,获取模型输出
            model_output = self.model(**encoded_input)
            
            # 获取句子的嵌入向量,这里假设模型输出的第一个元素包含嵌入信息
            sentence_embeddings = model_output[0][:, 0]

        # 归一化嵌入向量
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        
        # 将嵌入向量转换为列表格式并返回
        return sentence_embeddings.tolist()

测试:

print("> Create embedding model...")
embed_model_path = './AI-ModelScope/bge-small-zh-v1___5'
embed_model = EmbeddingModel(embed_model_path)

"""
> Create embedding model...
Loading EmbeddingModel from ./AI-ModelScope/bge-small-zh-v1___5.
"""
检索

编写一个知识库文档knowledge.txt

广州大学(Guangzhou University),简称广大(GU),是由广东省广州市人民政府举办的全日制普通高等学校,实行省市共建、以市为主的办学体制,是国家“111计划”建设高校、广东省和广州市高水平大学重点建设高校。广州大学的办学历史可以追溯到1927年创办的私立广州大学;1951年并入华南联合大学;1983年筹备复办,1984年定名为广州大学;2000年7月,经教育部批准,与广州教育学院(1953年创办)、广州师范学院(1958年创办)、华南建设学院西院(1984年创办)、广州高等师范专科学校(1985年创办)合并组建成立新的广州大学。
郑州机械研究所有限公司(以下简称郑机所)的前身机械科学研究院1956年始建于北京,是原机械工业部直属一类综合研究院所,现隶属于国资委中国机械科学研究总院集团有限公司。郑机所伴随着共和国的成长一路走来,应运而生于首都,碧玉年华献中原。多次搬迁,驻地从北京经漯河再到郑州;数易其名,由机械科学研究院到漯河机械研究所再到郑州机械研究所,现为郑州机械研究所有限公司。1956~1958年应运而生:依据全国人大一届二次会议的提议和第一机械工业部的决策,1956年3月6日,第一机械工业部发文《(56)机技研究第66号》,通知“机械科学实验研究院”(后改名为“机械科学研究院”)在北京成立。1959~1968年首次创业:承担国家重大科研项目与开发任务,以及行业发展规划以及标准制定等工作,如“九大设备”的若干关键技术等。1969~1972年搬迁河南:1969年按照“战备疏散”的要求,机械科学研究院主体迁建河南漯河,成立“漯河机械研究所”;1972年因发展需要,改迁河南郑州,成立郑州机械研究所。1973~1998年二次创业:先后隶属于国家机械工业委员会、机械电子工业部、机械工业部;1981年4月罗干由铸造室主任升任副所长,同年经国务院批准具备硕士学位授予权;1985年“葛洲坝二、三江工程及其水电机组项目”荣获国家科技进步特等奖。1999~2016年发展壮大:1999年转企改制,隶属于国资委中国机械科学研究总院;2008年被河南省首批认定为“高新技术企业”;2011年获批组建新型钎焊材料与技术国家重点实验室;2014年被工信部认定为“国家技术创新示范企业”;历经十多年开发出填补国内外空白的大型齿轮齿条试验装备,完成了对三峡升船机齿条42.2万次应力循环次数的疲劳寿命试验测试;营业收入从几千万发展到近6亿;2017年至今协同发展:2017年经公司制改制,更名为郑州机械研究所有限公司,一以贯之地坚持党对国有企业的领导,充分发挥党委把方向、管大局、保落实的领导作用;一以贯之地建立现代企业制度,持续推进改革改制,努力实现以高质量党建引领郑机所高质量发展。 
非洲野犬,属于食肉目犬科非洲野犬属哺乳动物。 又称四趾猎狗或非洲猎犬; 其腿长身短、体形细长;身上有鲜艳的黑棕色、黄色和白色斑块;吻通常黑色,头部中间有一黑带,颈背有一块浅黄色斑;尾基呈浅黄色,中段呈黑色,末端为白色,因此又有“杂色狼”之称。 非洲野犬分布于非洲东部、中部、南部和西南部一带。 栖息于开阔的热带疏林草原或稠密的森林附近,有时也到高山地区活动。其结群生活,没有固定的地盘,一般在一个较大的范围内逗留时间较长。非洲野犬性情凶猛,以各种羚羊、斑马、啮齿类等为食。奔跑速度仅次于猎; 雌犬妊娠期为69-73天,一窝十只仔,哺乳期持续6-12个星期。 其寿命11年。 非洲野犬正处在灭绝边缘,自然界中仅存两三千只。 非洲野犬被列入《世界自然保护联盟濒危物种红色名录》中,为濒危(EN)保护等级。 ",非洲野犬共有42颗牙齿(具体分布为:i=3/3;c=1/1;p=4/4;m=2/3x2),前臼齿比相对比其他犬科动物要大,因此可以磨碎大量的骨头,这一点很像鬣狗。 主要生活在非洲的干燥草原和半荒漠地带,活跃于草原、稀树草原和开阔的干燥灌木丛,甚至包括撒哈拉沙漠南部一些多山的地带。非洲野犬从来不到密林中活动。 

定义一个向量库索引类 VectorStoreIndex

# 定义向量库索引类
class VectorStoreIndex:
    """
    用于创建向量库索引,计算文本之间的相似度并进行查询的类。
    """

    def __init__(self, document_path: str, embed_model: EmbeddingModel) -> None:
        """
        初始化方法,从文件中加载文档并计算它们的嵌入向量。

        :param document_path: 包含文档的文件路径。
        :param embed_model: 用于生成文档嵌入向量的模型实例。
        """
        self.documents = []  # 存储文档文本的列表
        # 从文件中读取文档并添加到文档列表中
        for line in open(document_path, 'r', encoding='utf-8'):
            line = line.strip()
            self.documents.append(line)

        # 存储嵌入模型的引用
        self.embed_model = embed_model

        # 为所有文档计算嵌入向量
        self.vectors = self.embed_model.get_embeddings(self.documents)

        # 打印加载文档的数量和文件路径
        print(f'Loading {len(self.documents)} documents for {document_path}.')

    def get_similarity(self, vector1: List[float], vector2: List[float]) -> float:
        """
        计算两个向量之间的余弦相似度。

        :param vector1: 第一个向量。
        :param vector2: 第二个向量。
        :return: 两个向量之间的余弦相似度。
        """
        # 计算两个向量的点积
        dot_product = np.dot(vector1, vector2)
        # 计算两个向量的模
        magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
        # 如果模为0,返回0以避免除以0的错误
        if not magnitude:
            return 0
        # 返回归一化的点积作为余弦相似度
        return dot_product / magnitude

    def query(self, question: str, k: int = 1) -> List[str]:
        """
        根据问题查询最相似的文档。

        :param question: 查询的问题文本。
        :param k: 返回最相似文档的数量,默认为1。
        :return: 最相似的文档列表。
        """
        # 为问题文本计算嵌入向量
        question_vector = self.embed_model.get_embeddings([question])[0]
        # 计算问题向量与所有文档向量的相似度
        result = np.array([self.get_similarity(question_vector, vector) for vector in self.vectors])
        # 获取最相似的k个文档的索引
        return np.array(self.documents)[result.argsort()[-k:][::-1]].tolist()

测试:

print("> Create index...")​
doecment_path = './knowledge.txt'​
index = VectorStoreIndex(doecment_path, embed_model)

question = '介绍一下广州大学'print('> Question:', question)
context = index.query(question)print('> Context:', context)

如果知识库很大,需要将知识库切分成多个batch,然后分批次送入向量模型。这里,因为我们的知识库比较小,所以就直接传到了get_embeddings() 函数。

返回结果:

在这里插入图片描述
我们传入用户问题 介绍一下广州大学,可以看到,准确地返回了知识库中的第一条知识。

生成

编写Yuan2-2B-Mars-hf-download.py,下载大模型Yuan2-2B-Mars-hf:

# 源大模型下载​
from modelscope import snapshot_download​
model_dir = snapshot_download('IEITYuan/Yuan2-2B-Mars-hf', cache_dir='.')

定义一个大语言模型类 LLM:

# 导入必要的库
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# 定义大语言模型类
class LLM:
    """
    用于加载和使用Yuan2.0大型语言模型的类。
    """

    def __init__(self, model_path: str) -> None:
        """
        初始化方法,加载预训练的分词器和模型。

        :param model_path: 预训练模型的路径。
        """
        print("Creat tokenizer...")  # 打印创建分词器的信息
        # 加载预训练的分词器,设置不自动添加结束和开始标记
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, add_eos_token=False, add_bos_token=False, eos_token='<eod>')

        # 向分词器添加特殊标记
        self.tokenizer.add_tokens([
            '<sep>', '<pad>', '<mask>', '<predict>', '<FIM_SUFFIX>', '<FIM_PREFIX>', '<FIM_MIDDLE>',
            '<commit_before>', '<commit_msg>', '<commit_after>', '<jupyter_start>', '<jupyter_text>',
            '<jupyter_code>', '<jupyter_output>', '<empty_output>'
        ], special_tokens=True)

        print("Creat model...")  # 打印创建模型的信息
        # 加载预训练的因果语言模型,并将其设置为半精度浮点数以提高计算效率
        self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda()

        # 打印加载模型的信息
        print(f'Loading Yuan2.0 model from {model_path}.')

    def generate(self, question: str, context: List[str]):
        """
        根据问题和上下文生成回答。

        :param question:  用户提问,是一个str。
        :param context: 检索到的上下文信息,是一个List,默认是[],代表没有使用RAG。
        """
        # 如果提供了上下文,构建提示信息
        if context:
            prompt = f'背景:{" ".join(context)}\n问题:{question}\n请基于背景,回答问题。'
        else:
            prompt = question

        # 在提示信息后添加分隔符
        prompt += "<sep>"
        # 使用分词器将提示信息转换为模型输入的格式
        inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"].cuda()
        # 使用模型生成文本
        outputs = self.model.generate(inputs, do_sample=False, max_length=1024)
        # 解码生成的文本
        output = self.tokenizer.decode(outputs[0])

        # 打印生成的文本,只显示分隔符之后的部分
        print(output.split("<sep>")[-1])

测试:

print("> Create Yuan2.0 LLM...")​
model_path = './IEITYuan/Yuan2-2B-Mars-hf'​
llm = LLM(model_path)

print('> Without RAG:')​
llm.generate(question, [])print('> With RAG:')​
llm.generate(question, context)

在这里插入图片描述

打包

编写requirements.txt

transformers
torch
numpy

编写安装脚本:pip_install.sh

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt

新建main.py

from common import constants
from generation.llm import LLM
from indexing.embedding import EmbeddingModel
from retrieval.vector import VectorStoreIndex


def main():
    print("> Create embedding model...")
    embed_model_path = constants.EMBED_MODEL_PATH
    embed_model = EmbeddingModel(embed_model_path)

    print("> Create index...")
    document_path = constants.DOCUMENT_PATH
    index = VectorStoreIndex(document_path, embed_model)

    question = '介绍一下广州大学'
    print('> Question:', question)
    context = index.query(question)
    print('> Context:', context)

    print("> Create Yuan2.0 LLM...")
    model_path = constants.MODEL_PATH
    llm = LLM(model_path)
    print('> Without RAG:')
    llm.generate(question, [])
    print('> With RAG:')
    llm.generate(question, context)


if __name__ == '__main__':
    main()

在这里插入图片描述

部署

将代码部署到Github:https://github.com/itxaiohanglover/rag_demo

然后进入终端,导入写好的代码:

在这里插入图片描述
在这里插入图片描述

下载模型:

python setup.py

启动代码:

python main.py

在这里插入图片描述

实战 - AI科研助手

在这里插入图片描述

项目主要包含一个Streamlit开发的客户端,以及一个部署好的浪潮源大模型的服务端。​

  • 客户端接收到用户上传的PDF后,发送到服务端。服务端首先完成PDF内容解析,然后拼接摘要Prompt并输入源大模型,得到模型输出结果后,返回给客户端并展示给用户。​
  • 如果用户接下来进行提问,客户端将用户请求发送到服务端,服务端进行Embedding和Faiss检索,然后将检索到的chunks与用户请求拼接成Prompt并输入到源大模型,得到模型输出结果后,返回给客户端进行结构化,然后展示给用户。
项目结构

在这里插入图片描述
主模块:main.py

# 导入必要的库和模块
from langchain_community.document_loaders import PyPDFLoader  # 用于加载PDF文件的加载器
from common import constants  # 导入常量配置
import streamlit as st  # 导入Streamlit库,用于构建Web界面

from llm.yuan2_llm import Yuan2_LLM  # 导入自定义的大型语言模型类
from langchain_huggingface import HuggingFaceEmbeddings  # 导入HuggingFace嵌入模型
# 导入提示模板类
from prompts.chatbot_template import ChatBot
from prompts.summarizer_template import Summarizer

# 定义模型路径
model_path = constants.MODEL_PATH  # 从常量配置中获取模型路径

# 定义向量模型路径
embedding_model_path = constants.EMBED_MODEL_PATH  # 从常量配置中获取向量模型路径


# 定义一个函数,用于获取llm和embeddings
@st.cache_resource  # 使用Streamlit的缓存装饰器来缓存函数的结果
def get_models():
    llm = Yuan2_LLM(model_path)  # 创建Yuan2_LLM实例

    # 定义模型和编码的参数
    model_kwargs = {'device': 'cuda'}
    encode_kwargs = {'normalize_embeddings': True}  # 设置为True以计算余弦相似度
    embeddings = HuggingFaceEmbeddings(
        model_name=embedding_model_path,  # 向量模型的名称或路径
        model_kwargs=model_kwargs,  # 模型参数
        encode_kwargs=encode_kwargs,  # 编码参数
    )
    return llm, embeddings  # 返回创建的LLM和嵌入模型实例


def main():
    # 创建一个标题
    st.title('💬 Yuan2.0 AI科研助手')  # 设置Streamlit应用的标题

    # 获取llm和embeddings
    llm, embeddings = get_models()  # 调用get_models函数获取模型实例

    # 初始化summarizer
    summarizer = Summarizer(llm)  # 创建Summarizer实例用于生成文本摘要

    # 初始化ChatBot
    chatbot = ChatBot(llm, embeddings)  # 创建ChatBot实例用于回答问题

    # 上传pdf
    uploaded_file = st.file_uploader("Upload your PDF", type='pdf')  # 创建文件上传器,允许用户上传PDF文件

    if uploaded_file:
        # 加载上传PDF的内容
        file_content = uploaded_file.read()  # 读取上传的文件内容

        # 写入临时文件
        temp_file_path = "temp.pdf"  # 定义临时文件路径
        with open(temp_file_path, "wb") as temp_file:
            temp_file.write(file_content)  # 将文件内容写入临时文件

        # 加载临时文件中的内容
        loader = PyPDFLoader(temp_file_path)  # 创建PDF加载器实例
        docs = loader.load()  # 使用加载器加载文档内容

        st.chat_message("assistant").write(f"正在生成论文概括,请稍候...")  # 在Streamlit界面上显示消息

        # 生成概括
        summary = summarizer.summarize(docs)  # 调用summarizer的summarize方法生成摘要

        # 在聊天界面上显示模型的输出
        st.chat_message("assistant").write(summary)  # 显示生成的摘要

        # 接收用户问题
        if query := st.text_input("Ask questions about your PDF file"):  # 创建文本输入框,允许用户输入问题
            # 检索 + 生成回复
            chunks, response = chatbot.run(docs, query)  # 调用chatbot的run方法进行检索和生成回答

            # 在聊天界面上显示模型的输出
            st.chat_message("assistant").write(f"正在检索相关信息,请稍候...")  # 显示检索信息的消息
            st.chat_message("assistant").write(chunks)  # 显示检索到的文档片段

            st.chat_message("assistant").write(f"正在生成回复,请稍候...")  # 显示生成回答的消息
            st.chat_message("assistant").write(response)  # 显示生成的回答


if __name__ == '__main__':
    main()  # 如果是主程序,则调用main函数运行应用

提示词模块:

chatbot_template.py

from langchain_core.prompts import PromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter  # 导入文本分割器
from langchain.chains.question_answering import load_qa_chain  # 导入load_qa_chain,用于加载问答链
from langchain_community.vectorstores import FAISS
# 定义聊天机器人模板
chatbot_template = '''
假设你是一个AI科研助手,请基于背景,简要回答问题。

背景:
{context}

问题:
{question}
'''.strip()


# 定义ChatBot类
class ChatBot:
    """
    ChatBot类用于处理用户提问,并基于文档内容生成回答。
    """

    def __init__(self, llm, embeddings):
        self.prompt = PromptTemplate(
            input_variables=["text"],
            template=chatbot_template
        )  # 定义聊天机器人提示模板
        self.chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=self.prompt)  # 加载问答链
        self.embeddings = embeddings  # 嵌入模型,用于文档向量化

        # 加载文本分割器,用于将长文本切分成小块
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=450,
            chunk_overlap=10,
            length_function=len
        )

    def run(self, docs, query):
        """
        处理用户提问,生成回答。

        :param docs: 文档列表,每个文档包含page_content属性
        :param query: 用户的提问
        :return: 检索到的文档片段和生成的回答
        """
        # 读取所有文档内容
        text = ''.join([doc.page_content for doc in docs])

        # 使用文本分割器切分成chunks
        all_chunks = self.text_splitter.split_text(text=text)

        # 将文本chunks转换为向量并存储
        VectorStore = FAISS.from_texts(all_chunks, embedding=self.embeddings)

        # 检索与提问最相似的chunks
        chunks = VectorStore.similarity_search(query=query, k=1)

        # 使用问答链生成回答
        response = self.chain.run(input_documents=chunks, question=query)

        return chunks, response  # 返回检索到的文档片段和生成的回答

summarizer_template.py

# 导入必要的库和模块
from langchain.chains.llm import LLMChain  # 导入LLMChain,用于构建基于LLM的生成链
from langchain_core.prompts import PromptTemplate  # 导入PromptTemplate,用于构建提示模板


# 定义摘要模板
summarizer_template = """
假设你是一个AI科研助手,请用一段话概括下面文章的主要内容,200字左右。

{text}
"""

# 定义Summarizer类
class Summarizer:
    """
    Summarizer类用于将长文本内容压缩成简短的摘要。
    """

    def __init__(self, llm):
        self.llm = llm  # LLM实例,用于生成文本
        self.prompt = PromptTemplate(
            input_variables=["text"],
            template=summarizer_template
        )  # 定义摘要提示模板
        self.chain = LLMChain(llm=self.llm, prompt=self.prompt)  # 创建LLMChain实例,用于生成摘要

    def summarize(self, docs):
        """
        从文档中生成摘要。

        :param docs: 文档列表,其中每个文档包含page_content属性
        :return: 生成的摘要文本
        """
        # 从第一页中获取摘要内容,假设摘要位于'ABSTRACT'和'KEY WORDS'之间
        content = docs[0].page_content.split('ABSTRACT')[1].split('KEY WORDS')[0]
        summary = self.chain.run(content)  # 使用LLMChain生成摘要
        return summary  # 返回生成的摘要

源大模型模块:
yuan2_llm.py

# 导入必要的库
from typing import List, Optional, Any
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# 导入常量配置
from common import constants

# 导入LLM基类
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun

# 定义模型路径
model_path = constants.MODEL_PATH

# 定义模型数据类型
torch_dtype = torch.bfloat16


# 定义源大模型类
class Yuan2_LLM(LLM):
    """
    YUAN2_LLM类用于加载和使用预训练的大型语言模型。
    它继承自langchain的LLM基类,并实现了自己的_call方法来生成文本。
    """

    # 类变量,用于存储分词器和模型实例
    tokenizer: AutoTokenizer = None
    model: AutoModelForCausalLM = None

    def __init__(self, model_path: str):
        super().__init__()

        # 加载预训练的分词器和模型
        print("Creating tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, add_eos_token=False, add_bos_token=False,
                                                       eos_token='<eod>')
        # 添加特殊标记
        self.tokenizer.add_tokens(
            ['<sep>', '<pad>', '<mask>', '<predict>', '<FIM_SUFFIX>', '<FIM_PREFIX>', '<FIM_MIDDLE>', '<commit_before>',
             '<commit_msg>', '<commit_after>', '<jupyter_start>', '<jupyter_text>', '<jupyter_code>',
             '<jupyter_output>', '<empty_output>'], special_tokens=True)

        print("Creating model...")
        self.model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype,
                                                          trust_remote_code=True).cuda()

    def _call(
            self,
            prompt: str,
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> str:
        """
        生成文本的方法,根据输入的prompt生成响应。

        :param prompt: 输入的文本提示
        :param stop: 停止生成的标记列表
        :param run_manager: 运行管理器,用于监控和管理生成过程
        :param kwargs: 其他可选参数
        :return: 生成的文本响应
        """
        prompt = prompt.strip()
        prompt += "<sep>"
        inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"].cuda()
        outputs = self.model.generate(inputs, do_sample=False, max_length=4096)
        output = self.tokenizer.decode(outputs[0])
        response = output.split("<sep>")[-1].split("<eod>")[0]

        return response

    @property
    def _llm_type(self) -> str:
        """
        返回模型的类型标识。

        :return: 模型类型字符串
        """
        return "Yuan2_LLM"
运行效果
bash run.sh

启动脚本:run.sh

streamlit run main.py --server.address 127.0.0.1 --server.port 6006

论文概括:
在这里插入图片描述

What kind of attention architecture is LFA?

在这里插入图片描述

Mobile Agent

介绍

在这里插入图片描述

  • 背景:现如今,我们的各种办公都是围绕电脑手机完成的,我们经常会遇到一些重复性的脑力劳动,此时我就会想着如何把这些重复性的步骤录入到一个脚本中,从而解放双手。而在不同应用程序中,开发脚本的难度也各有不同,一般其难度这么排序。
    • 最简单:直接开放API接口,可以直接写脚本调用,如飞书、企业微信、钉钉
    • 略微难点:有网页版程序,可以用一些“爬虫”的方法为己所用
    • 难:只有GUI的电脑程序,要么模拟键鼠动作写脚本,要么进行逆向。
    • 最难:只有手机APP,如闲鱼……

然而当今的中国互联网,各大厂商为了圈流量,不约而同的把用户往自家APP上赶,哪怕它们同时有网页端可用,但相比APP,其功能也残缺不全。互联网的围墙越建越高,以后谁要是想打通这堵围墙,整合各种APP的信息,比如:

在这里插入图片描述

  • 原理:Mobile-Agent是一种可用于操作手机的大模型智能体系统,可以实现一句话自动操作手机,是纯视觉方案,不受操作系统限制,能够通过视觉感知工具在移动应用的前端界面中准确识别和定位视觉和文本元素。其原理简单的说,就是把手机截图发给Agent,然后Agent进行思考,决定下一步点哪里,输入什么文本,执行后再截图发给Agent接着决策,直到完成任务。

Action Space(可执行操作)【7种】:

  • 点击文本
  • 点击图表
  • 打字
  • 上划&下划
  • 返回上一页面
  • 返回桌面
  • 结束

Mobile-Agent-v1:不断的执行规划和动作

  • 论文:Mobile-Agent: Autonomous Multi-Modal Mobile Device Agent with Visual Perception,202401
  • 问题:原始的多模态大模型是没有生成精确坐标的grounding能力的,即它即使知道下一步要点这个按钮,但没法报出其xy坐标。
  • 解决方案:
    • 使用OCR工具提出文本和对应的坐标,让Agent点击文字来代替定位。
    • 用图像分割工具提出所有图标及其坐标,让Agent选择图标来代替定位。【图像分割工具提出的所有图标及其坐标(用红点标出)】在这里插入图片描述

Mobile-Agent-v2:引入了多智能体

  • 论文:Mobile-Agent-v2: Mobile Device Operation Assistant with Effective Navigation via Multi-Agent Collaboration,202406
  • 问题:原本的单智能体方案有个比较麻烦的问题,即长序列问题,哪怕只执行了7步,其产生的输入Token数量也高达12k,过长的Token序列对模型性能的影响很大,影响的不止是其行为的正确率,还有模型的推理性能。
  • 解决方案:
    • 3个Agent各司其职:
      • 规划:Planning Agent
      • 执行:Decision Agent
      • 反馈:Reflection Agent

在这里插入图片描述

实验效果:

  • 图表说明:
    • SR:完成全部操作的的成功率。
    • CR:所有步数中正确操作的百分比。
    • DA:决策的准确率。
    • RA:反思的准确率。
    • Know.:在prompt中告诉了模型额外的知识,如操作指南。
    • Basic Instruction:简单的指令。
    • Advance Instruction:复杂的指令。
  • 非英文场景(其实就是中文场景)的性能:在这里插入图片描述
  • 英文场景的性能:在这里插入图片描述
  • 简单指令和复杂指令的example:在这里插入图片描述
  • 操作轮数对Mobile-Agent行为正确率的影响:单智能体的Mobile-Agent-v1在操作轮数多了之后,有效操作明显变少,而采用了多智能体的Mobile-Agent-v2则不受影响。在这里插入图片描述
  • 消融实验:没有一个Agent是多余的。在这里插入图片描述
  • 比对了不同多模态模型的性能:GPT-4V遥遥领先,第二名的准确率连它一半都没有。在这里插入图片描述

阿里云百炼平台

链接:https://bailian.console.aliyun.com/

开通 阿里云百炼平台 ,获得 qwen-vl-plus 的限时免费使用额度100w Token,有效期为30天。

1.登录账号,确认开通
2.查看免费赠送额度

在这里插入图片描述
3.创建API-KEY(记得保存)
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
最好找个备忘录保存一下

安装 Android Studio

官网:https://developer.android.google.cn/studio?hl=zh-cn

Android Studio 默认会把你的SDK下载放在C盘,到时候你的C盘就炸了,谷歌太坏了,很多新手都会直接下一步,这里你需要选择Custom,切记切记,然后点击 Next 下一步。

在这里插入图片描述
勾选SDK,以及SDK的安装路径

在这里插入图片描述
给Android模拟器进行内存设置,推荐多少G就选择多少G,这里直接Next。

在这里插入图片描述
左侧的三项,每一项选中之后点击Accpet,接受协议,三项都通过之后,你会看到Finish按钮亮起,点击Finish。

在这里插入图片描述

新建一个虚拟手机

1.New Project【新建一个项目】

在这里插入图片描述
2.选择Empty Activity

在这里插入图片描述
3.配置项目

在这里插入图片描述
4.等待安装依赖

在这里插入图片描述
5.创建虚拟手机

在这里插入图片描述
选择Pixel 8后,一直next

在这里插入图片描述
6.设置取消勾选

在这里插入图片描述
7.启动虚拟机

在这里插入图片描述
8.将谷歌日历移动到桌面

  • 双指从下往上拉模拟手机使用,打开应用汇总界面​
  • 长按Calendar应用将其拖拽至手机桌面

在这里插入图片描述

运行一个Demo项目

技术栈介绍

  • 图像处理与识别:
    • OCR(光学字符识别): 用于从截图中提取文本内容。使用了 damo/cv_resnet18_ocr-detection-line-level_damo 和 damo/cv_convnextTiny_ocr-recognition-document_damo 两个模型。
    • 图标检测: 使用 GroundingDINO 模型来检测屏幕上的图标。
  • 自然语言处理:
    • Qwen-VL 系列模型: 用于自然语言理解和生成,模型的选择可以是本地模型(如 qwen-vl-chat)或通过API访问(如 qwen-vl-plus)。
    • Prompt Engineering: 使用prompt来指导模型生成适合的响应,包括操作指令生成、反思生成、记忆生成和操作流程规划。
  • 设备控制:
    • 使用 get_screenshot、tap、slide 等函数通过ADB接口操作Android设备,如截图、点击、滑动、输入文本等。

拉取代码

git lfs install​
git clone https://www.modelscope.cn/datasets/Datawhale/MobileAgent_V2_Demo_qwenVL.git​
cd MobileAgent_V2_Demo_qwenVL

创建环境

conda create -n moblieagent python=3.9.19​
conda activate moblieagent

使用Pycharm打开项目文件,安装对应依赖

在这里插入图片描述

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple  -r win_requirements.txt

在这里插入图片描述
修改run.py中的adb_path 和 qwen_api 变量

adb_path = "E:/Android/Sdk/platform-tools/adb.exe"
# windows的路径未修改过应为"C:/Users/{your username}/AppData/Local/Android/Sdk/platform-tools/adb"

qwen_api = "上面步骤中申请的API-KEY"

在这里插入图片描述

运行run.py

在这里插入图片描述
如果遇到cudart64_110.dll not found等dll文件找不到,打开:🔗https://www.dllme.com,搜索对应的dll文件复制到C:\Windows\System32下。

流程分析

  1. 初始化
  • 配置设定: 设置ADB路径,指令内容,模型类型,API密钥等基本配置。
  • 模型加载: 根据配置加载OCR模型、图标检测模型、以及Qwen-VL模型(可以是本地或API调用)。
  1. 获取屏幕信息
  • 截图: 使用ADB获取当前手机屏幕截图并保存。
  • OCR识别: 通过OCR模型检测截图中的文本块及其坐标位置。
  • 图标检测: 通过GroundingDINO模型检测截图中的图标位置,并将图标裁剪出来进行进一步识别。
  • 结果合并: 将文本块和图标识别结果整理成统一格式的 perception_infos 列表,每个元素包含检测到的内容(文本或图标描述)及其屏幕坐标。
  1. 处理用户指令
  • 生成操作指令: 通过调用Qwen-VL模型,根据 perception_infos 生成操作指令(如点击、滑动、输入等)。
  • 执行操作: 根据生成的操作指令通过ADB接口对手机进行相应的操作。
  1. 反思与记忆
  • 反思: 项目支持通过对比前后两次截图的差异,生成反思Prompt,分析当前操作是否正确或需要修正。
  • 记忆: 项目可以将重要的信息保存到记忆中,以便后续操作中更好地理解和处理。
  1. 迭代循环
  • 重复操作: 以上过程在一个循环中不断重复,直到完成所有指令或满足停止条件。

代码解析

1.环境设置与初始化

  • 功能:设置ADB路径、用户指令、选择模型和API的调用方式等配置。
  • 思路:在开始前,项目通过设置 ADB 路径、用户指令、API调用方式以及模型选择来初始化项目运行的基础环境。
  • 代码实现:
# 路径应为"/Users/用户名/Library/Android/sdk/platform-tools/adb"
adb_path = "E:\\Android\\Sdk\\platform-tools\\adb"

# Your instruction
instruction = "Read the Screen, tell me what day it is today. Then open Play Store."

# Choose between "api" and "local". api: use the qwen api. local: use the local qwen checkpoint
caption_call_method = "api"

# Choose between "qwen-vl-plus" and "qwen-vl-max" if use api method. Choose between "qwen-vl-chat" and "qwen-vl-chat-int4" if use local method.
caption_model = "qwen-vl-plus"

# If you choose the api caption call method, input your Qwen api here
qwen_api = "XXXXXXX"

# You can add operational knowledge to help Agent operate more accurately.
add_info = "If you want to tap an icon of an app, use the action \"Open app\". If you want to exit an app, use the action \"Home\""

# Reflection Setting: If you want to improve the operating speed, you can disable the reflection agent. This may reduce the success rate.
reflection_switch = True

# Memory Setting: If you want to improve the operating speed, you can disable the memory unit. This may reduce the success rate.
memory_switch = True

2.聊天历史初始化

  • 功能:初始化不同对话历史(如操作历史、反思历史、记忆历史)用于后续交互。
  • 思路:不同的聊天初始化函数用于分别构建操作对话历史、反思对话历史、记忆对话历史等,这样在不同阶段可以复用这些历史对话记录来生成决策。
  • 代码实现:
def init_action_chat():
    operation_history = []
    sysetm_prompt = "You are a helpful AI mobile phone operating assistant. You need to help me operate the phone to complete the user\'s instruction."
    operation_history.append({'role': 'system','content': [{'text': sysetm_prompt}]})
    return operation_history

3.图像处理与信息提取

  • 功能:截取手机屏幕、进行OCR识别、图标检测、坐标处理等。
  • 思路:该模块负责从手机截图中提取有用的信息,包括文本和图标,并将这些信息转化为后续操作的输入。
  • 代码实现:
def get_perception_infos(adb_path, screenshot_file):
    # 使用adb命令获取设备的屏幕截图
    get_screenshot(adb_path)
    
    # 打开截图文件并获取其宽度和高度
    width, height = Image.open(screenshot_file).size
    
    # 使用OCR技术识别截图中的文本及其坐标
    text, coordinates = ocr(screenshot_file, ocr_detection, ocr_recognition)
    # 合并文本块,以便于处理和显示
    text, coordinates = merge_text_blocks(text, coordinates)
    
    # 计算文本坐标的中心点
    center_list = [[(coordinate[0]+coordinate[2])/2, (coordinate[1]+coordinate[3])/2] for coordinate in coordinates]
    # 在截图上绘制文本坐标的中心点
    draw_coordinates_on_image(screenshot_file, center_list)
    
    # 初始化一个空列表来存储感知信息
    perception_infos = []
    # 遍历文本坐标,为每个文本创建一个感知信息字典
    for i in range(len(coordinates)):
        perception_info = {"text": "text: " + text[i], "coordinates": coordinates[i]}
        perception_infos.append(perception_info)
    
    # 使用目标检测模型识别截图中的图标
    coordinates = det(screenshot_file, "icon", groundingdino_model)
    
    # 遍历图标坐标,为每个图标创建一个感知信息字典
    for i in range(len(coordinates)):
        perception_info = {"text": "icon", "coordinates": coordinates[i]}
        perception_infos.append(perception_info)
    
    # 初始化空列表来存储图标的坐标和ID
    image_box = []
    image_id = []
    # 遍历感知信息,提取图标的坐标和ID
    for i in range(len(perception_infos)):
        if perception_infos[i]['text'] == 'icon':
            image_box.append(perception_infos[i]['coordinates'])
            image_id.append(i)
    
    # 对每个图标坐标进行裁剪并保存到临时文件夹
    for i in range(len(image_box)):
        crop(screenshot_file, image_box[i], image_id[i])
    # 获取临时文件夹中的所有文件
    images = get_all_files_in_folder(temp_file)
    # 如果有文件,则对文件名进行排序
    if len(images) > 0:
        images = sorted(images, key=lambda x: int(x.split('/')[-1].split('.')[0]))
        image_id = [int(image.split('/')[-1].split('.')[0]) for image in images]
        icon_map = {}
        # 定义一个提示,用于描述图标的形状和颜色
        prompt = 'This image is an icon from a phone screen. Please briefly describe the shape and color of this icon in one sentence.'
        # 根据配置选择本地或API方式生成图标描述
        if caption_call_method == "local":
            for i in range(len(images)):
                image_path = os.path.join(temp_file, images[i])
                icon_width, icon_height = Image.open(image_path).size
                # 过滤掉过大的图标
                if icon_height > 0.8 * height or icon_width * icon_height > 0.2 * width * height:
                    des = "None"
                else:
                    des = generate_local(tokenizer, model, image_path, prompt)
                icon_map[i+1] = des
        else:
            for i in range(len(images)):
                images[i] = os.path.join(temp_file, images[i])
            icon_map = generate_api(images, prompt)
        # 更新感知信息中的图标描述
        for i, j in zip(image_id, range(1, len(image_id)+1)):
            if icon_map.get(j):
                perception_infos[i]['text'] = "icon: " + icon_map[j]
    # 更新感知信息中的坐标为坐标中心
    for i in range(len(perception_infos)):
        perception_infos[i]['coordinates'] = [int((perception_infos[i]['coordinates'][0]+perception_infos[i]['coordinates'][2])/2), int((perception_infos[i]['coordinates'][1]+perception_infos[i]['coordinates'][3])/2)]
        
    # 返回感知信息、截图宽度和高度
    return perception_infos, width, height
  1. 深度学习模型加载与推理
  • 功能:加载和初始化所需的深度学习模型,处理用户的指令。
  • 思路:根据用户选择,项目会加载本地或API提供的模型来进行图像描述、文本识别、图标检测等任务。通过选择不同模型和API,可以适应不同的应用场景和硬件环境。
  • 代码实现:
device = "cpu"
torch.manual_seed(1234)
if caption_call_method == "local":
    # Load local models...
elif caption_call_method == "api":
    # Use API for models...

5.操作与执行

  • 功能:根据模型输出的操作指令,执行相应的手机操作(点击、滑动、返回等)。
  • 思路:这一部分是项目的核心逻辑,它根据分析得到的操作指令执行相应的手机操作,来完成用户的任务指令。
  • 代码实现:
if "Open app" in action:
    # Open a specific app...
elif "Tap" in action:
    # Tap on a specific coordinate...
elif "Swipe" in action:
    # Swipe from one coordinate to another...
elif "Type" in action:
    # Type text...
elif "Back" in action:
    back(adb_path)
elif "Home" in action:
    home(adb_path)
elif "Stop" in action:
    break

6.反思与记忆模块

  • 功能:通过反思上一次的操作结果来调整下一步操作的策略,并将有价值的信息存储在记忆中。
  • 思路:通过反思模块,系统会基于之前的操作结果来判断是否需要调整策略,并将重要的信息存储到内存模块中,以便在后续操作中参考。
  • 代码实现:
if reflection_switch:
    # 如果启用了反射性对话开关,则获取反射性对话的提示
    prompt_reflect = get_reflect_prompt(...)
    # 初始化反射性对话
    chat_reflect = init_reflect_chat()
    # 将用户对两张截图的反射性对话添加到聊天历史中
    chat_reflect = add_response_two_image("user", prompt_reflect, chat_reflect, [last_screenshot_file, screenshot_file])

    # 调用本地文件处理函数,传入操作、API密钥和模型名称
    output_reflect = call_with_local_file(chat_action, api_key=qwen_api, model='qwen-vl-plus')
    # 从输出中提取反射性对话的回答部分
    reflect = output_reflect.split("### Answer ###")[-1].replace("\n", " ").strip()
    # 将系统的回答添加到聊天历史中
    chat_reflect = add_response("system", output_reflect, chat_reflect)
    
    # 如果反射性对话的回答中包含字母'A',则将当前的思考、总结和行动添加到历史记录中
    if 'A' in reflect:
        thought_history.append(thought)
        summary_history.append(summary)
        action_history.append(action)
    # 其他条件...

7.主循环与终止条件

  • 功能:主循环执行多轮操作,并根据一定条件终止循环。
  • 思路:项目在一个循环中进行,直到任务完成或达到终止条件。每次循环都会根据新的屏幕截图和用户指令更新操作,并在适当的时候进行反思和策略调整。
  • 代码实现:
while True:
    iter += 1
    # First iteration...
    # Action decision...
    # Memory update...
    # Reflection...
    if "Stop" in action:
        break
    time.sleep(5)

8.总结功能

  • 功能:对项目进行总结,提取核心内容,确保项目达成目标。
  • 思路:这一部分通过对完成任务的总结,验证项目的执行效果,确保达到用户的预期目标。
  • 代码实现:
completed_requirements = output_planning.split("### Completed contents ###")[-1].replace("\n", " ").strip()

api模块:

import base64
import requests

# 定义一个函数,用于将图像文件编码为Base64字符串
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        # 读取图像文件内容,并使用base64编码
        return base64.b64encode(image_file.read()).decode('utf-8')

# 定义一个函数,用于发送聊天数据到API并获取推理结果
def inference_chat(chat, model, api_url, token):    
    # 设置请求头,包括内容类型和授权令牌
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {token}"
    }

    # 准备发送到API的数据,包括模型名称、消息列表、最大令牌数、温度和种子
    data = {
        "model": model,
        "messages": [],
        "max_tokens": 2048,
        'temperature': 0.0,
        "seed": 1234
    }

    # 遍历聊天数据,将每条消息添加到数据的messages列表中
    for role, content in chat:
        data["messages"].append({"role": role, "content": content})

    # 使用循环来处理API请求,直到成功获取响应或发生异常
    while True:
        try:
            # 发送POST请求到API,传递数据和请求头
            res = requests.post(api_url, headers=headers, json=data)
            # 解析响应数据,获取聊天推理结果
            res_json = res.json()
            res_content = res_json['choices'][0]['message']['content']
        except:
            # 如果请求过程中发生异常,打印网络错误信息
            print("Network Error:")
            try:
                # 尝试打印详细的错误信息
                print(res.json())
            except:
                # 如果无法获取错误信息,打印请求失败信息
                print("Request Failed")
        else:
            # 如果请求成功,跳出循环
            break
    
    # 返回推理结果
    return res_content
  • encode_image 函数:将指定路径的图像文件读取并编码为Base64字符串,这通常用于将图像数据作为字符串发送到API。
  • inference_chat 函数:构建一个包含聊天消息、模型名称和其他参数的数据字典,然后发送POST请求到指定的API URL。这个函数使用循环来确保即使在网络错误或其他异常情况下也能持续尝试发送请求,直到成功或明确失败。函数返回API返回的推理结果。

聊天模块:

import copy
from MobileAgent.api import encode_image

# 初始化操作聊天历史记录的函数
def init_action_chat():
    operation_history = []
    # 系统提示信息,用于指导AI助手的行为
    sysetm_prompt = "You are a helpful AI mobile phone operating assistant. You need to help me operate the phone to complete the user's instruction."
    # 将系统提示信息作为第一条消息添加到聊天历史记录中
    operation_history.append(["system", [{"type": "text", "text": sysetm_prompt}]])
    return operation_history

# 初始化反射聊天历史记录的函数
def init_reflect_chat():
    operation_history = []
    # 系统提示信息,用于指导AI助手的行为
    sysetm_prompt = "You are a helpful AI mobile phone operating assistant."
    # 将系统提示信息作为第一条消息添加到聊天历史记录中
    operation_history.append(["system", [{"type": "text", "text": sysetm_prompt}]])
    return operation_history

# 初始化记忆聊天历史记录的函数
def init_memory_chat():
    operation_history = []
    # 系统提示信息,用于指导AI助手的行为
    sysetm_prompt = "You are a helpful AI mobile phone operating assistant."
    # 将系统提示信息作为第一条消息添加到聊天历史记录中
    operation_history.append(["system", [{"type": "text", "text": sysetm_prompt}]])
    return operation_history

# 向聊天历史记录中添加响应的函数
def add_response(role, prompt, chat_history, image=None):
    # 深拷贝聊天历史记录,以避免修改原始数据
    new_chat_history = copy.deepcopy(chat_history)
    # 如果提供了图像,则将其编码为Base64并构建相应的内容
    if image:
        base64_image = encode_image(image)
        content = [
            {
                "type": "text", 
                "text": prompt
            },
            {
                "type": "image_url", 
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                }
            },
        ]
    else:
        # 如果没有提供图像,则只构建文本内容
        content = [
            {
            "type": "text", 
            "text": prompt
            },
        ]
    # 将新消息添加到聊天历史记录中
    new_chat_history.append([role, content])
    return new_chat_history

# 向聊天历史记录中添加包含两张图像的响应的函数
def add_response_two_image(role, prompt, chat_history, image):
    # 深拷贝聊天历史记录,以避免修改原始数据
    new_chat_history = copy.deepcopy(chat_history)
    # 对两张图像进行Base64编码
    base64_image1 = encode_image(image[0])
    base64_image2 = encode_image(image[1])
    # 构建包含文本和两张图像的内容
    content = [
        {
            "type": "text", 
            "text": prompt
        },
        {
            "type": "image_url", 
            "image_url": {
                "url": f"data:image/jpeg;base64,{base64_image1}"
            }
        },
        {
            "type": "image_url", 
            "image_url": {
                "url": f"data:image/jpeg;base64,{base64_image2}"
            }
        },
    ]
    # 将新消息添加到聊天历史记录中
    new_chat_history.append([role, content])
    return new_chat_history

# 打印聊天历史记录状态的函数
def print_status(chat_history):
    # 打印分隔线
    print("*"*100)
    # 遍历聊天历史记录,打印每条消息的角色和内容
    for chat in chat_history:
        print("role:", chat[0])
        print(chat[1][0]["text"], end="")
        # 如果消息包含图像,则打印图像占位符
        print("<image>"*(len(chat[1])-1) + "\n")
    # 打印分隔线
    print("*"*100)
  • init_action_chat、init_reflect_chat、init_memory_chat 函数:分别初始化不同类型聊天(操作、反射、记忆)的历史记录,每条记录都以系统提示信息开始。
  • add_response 函数:向聊天历史记录中添加一条新消息,该消息可以包含文本和图像。如果提供了图像,它会使用 encode_image 函数将图像编码为Base64字符串,并构建适当的内容结构。
  • add_response_two_image 函数:类似于 add_response,但允许同时添加两张图像。
  • print_status 函数:打印聊天历史记录的状态,显示每条消息的角色和内容,如果消息包含图像,则打印图像占位符。

控制模块:

import os
import time
import subprocess
from PIL import Image

# 获取Android设备的屏幕截图
def get_screenshot(adb_path):
    # 构建ADB命令以删除旧的截图文件
    command = adb_path + " shell rm /sdcard/screenshot.png"
    # 执行命令,capture_output=True表示捕获输出,text=True表示输出为文本,shell=True表示允许shell命令
    subprocess.run(command, capture_output=True, text=True, shell=True)
    # 等待0.5秒,确保命令执行完成
    time.sleep(0.5)
    # 构建ADB命令以捕获新的屏幕截图
    command = adb_path + " shell screencap -p /sdcard/screenshot.png"
    subprocess.run(command, capture_output=True, text=True, shell=True)
    # 再次等待0.5秒
    time.sleep(0.5)
    # 构建ADB命令以将截图文件从设备拉取到本地
    command = adb_path + " pull /sdcard/screenshot.png ./screenshot"
    subprocess.run(command, capture_output=True, text=True, shell=True)
    # 定义截图文件的本地路径
    image_path = "./screenshot/screenshot.png"
    save_path = "./screenshot/screenshot.jpg"
    # 使用PIL库打开图像文件
    image = Image.open(image_path)
    # 将图像转换为RGB格式并保存为JPEG文件
    image.convert("RGB").save(save_path, "JPEG")
    # 删除临时的PNG截图文件
    os.remove(image_path)

# 模拟在Android设备上进行点击操作
def tap(adb_path, x, y):
    # 构建ADB命令以模拟点击
    command = adb_path + f" shell input tap {x} {y}"
    # 执行命令
    subprocess.run(command, capture_output=True, text=True, shell=True)

# 模拟在Android设备上输入文本
def type(adb_path, text):
    # 将文本中的换行符替换为下划线,以避免输入错误
    text = text.replace("\\n", "_").replace("\n", "_")
    for char in text:
        # 根据字符类型构建并执行不同的ADB命令
        if char == ' ':
            command = adb_path + f" shell input text %s"
            subprocess.run(command, capture_output=True, text=True, shell=True)
        elif char == '_':
            command = adb_path + f" shell input keyevent 66"
            subprocess.run(command, capture_output=True, text=True, shell=True)
        elif 'a' <= char <= 'z' or 'A' <= char <= 'Z' or char.isdigit():
            command = adb_path + f" shell input text {char}"
            subprocess.run(command, capture_output=True, text=True, shell=True)
        elif char in '-.,!?@\'°/:;()':
            command = adb_path + f" shell input text \"{char}\""
            subprocess.run(command, capture_output=True, text=True, shell=True)
        else:
            command = adb_path + f" shell am broadcast -a ADB_INPUT_TEXT --es msg \"{char}\""
            subprocess.run(command, capture_output=True, text=True, shell=True)

# 模拟在Android设备上进行滑动操作
def slide(adb_path, x1, y1, x2, y2):
    # 构建ADB命令以模拟滑动
    command = adb_path + f" shell input swipe {x1} {y1} {x2} {y2} 500"
    # 执行命令
    subprocess.run(command, capture_output=True, text=True, shell=True)

# 模拟在Android设备上执行返回操作
def back(adb_path):
    # 构建ADB命令以模拟返回键
    command = adb_path + f" shell input keyevent 4"
    # 执行命令
    subprocess.run(command, capture_output=True, text=True, shell=True)

# 模拟在Android设备上执行回到主屏幕的操作
def home(adb_path):
    # 构建ADB命令以模拟回到主屏幕
    command = adb_path + f" shell am start -a android.intent.action.MAIN -c android.intent.category.HOME"
    # 执行命令
    subprocess.run(command, capture_output=True, text=True, shell=True)
  • get_screenshot 函数:通过ADB命令获取Android设备的屏幕截图,并将其保存为JPEG格式的文件。
  • tap 函数:通过ADB命令模拟在指定坐标的点击操作。
  • type 函数:通过ADB命令模拟在Android设备上输入文本。
  • slide 函数:通过ADB命令模拟在Android设备上进行滑动操作。
  • back 函数:通过ADB命令模拟按下Android设备的返回键。
  • home 函数:通过ADB命令模拟回到Android设备的主屏幕。

图像处理模块(特征提取和相似度计算):

import math
import cv2
import numpy as np
from PIL import Image, ImageDraw
import clip
import torch

# 裁剪图像的函数
def crop_image(img, position):
    # 定义计算两点之间距离的内部函数
    def distance(x1,y1,x2,y2):
        return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2))    
    # 将传入的位置坐标转换为列表形式
    position = position.tolist()
    # 对坐标点按照x轴和y轴的值进行排序
    for i in range(4):
        for j in range(i+1, 4):
            if(position[i][0] > position[j][0]):
                tmp = position[j]
                position[j] = position[i]
                position[i] = tmp
    if position[0][1] > position[1][1]:
        tmp = position[0]
        position[0] = position[1]
        position[1] = tmp

    if position[2][1] > position[3][1]:
        tmp = position[2]
        position[2] = position[3]
        position[3] = tmp

    # 根据排序后的坐标点计算裁剪区域的四个角点
    x1, y1 = position[0][0], position[0][1]
    x2, y2 = position[2][0], position[2][1]
    x3, y3 = position[3][0], position[3][1]
    x4, y4 = position[1][0], position[1][1]

    corners = np.zeros((4,2), np.float32)
    corners[0] = [x1, y1]
    corners[1] = [x2, y2]
    corners[2] = [x4, y4]
    corners[3] = [x3, y3]

    # 计算图像的宽度和高度
    img_width = distance((x1+x4)/2, (y1+y4)/2, (x2+x3)/2, (y2+y3)/2)
    img_height = distance((x1+x2)/2, (y1+y2)/2, (x4+x3)/2, (y4+y3)/2)

    # 定义转换后的坐标点
    corners_trans = np.zeros((4,2), np.float32)
    corners_trans[0] = [0, 0]
    corners_trans[1] = [img_width - 1, 0]
    corners_trans[2] = [0, img_height - 1]
    corners_trans[3] = [img_width - 1, img_height - 1]

    # 使用透视变换矩阵进行图像裁剪
    transform = cv2.getPerspectiveTransform(corners, corners_trans)
    dst = cv2.warpPerspective(img, transform, (int(img_width), int(img_height)))
    return dst

# 计算矩形框大小的函数
def calculate_size(box):
    return (box[2]-box[0]) * (box[3]-box[1])

# 计算两个矩形框交并比的函数
def calculate_iou(box1, box2):
    xA = max(box1[0], box2[0])
    yA = max(box1[1], box2[1])
    xB = min(box1[2], box2[2])
    yB = min(box1[3], box2[3])
    
    interArea = max(0, xB - xA) * max(0, yB - yA)
    box1Area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2Area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    unionArea = box1Area + box2Area - interArea
    iou = interArea / unionArea
    
    return iou

# 裁剪图像并保存的函数
def crop(image, box, i, text_data=None):
    image = Image.open(image)
    # 如果提供了文本数据,则在图像上绘制矩形框和文本
    if text_data:
        draw = ImageDraw.Draw(image)
        draw.rectangle(((text_data[0], text_data[1]), (text_data[2], text_data[3])), outline="red", width=5)
        # font_size = int((text_data[3] - text_data[1])*0.75)
        # font = ImageFont.truetype("arial.ttf", font_size)
        # draw.text((text_data[0]+5, text_data[1]+5), str(i), font=font, fill="red")

    cropped_image = image.crop(box)
    cropped_image.save(f"./temp/{i}.jpg")

# 判断一个点是否在矩形框内的函数
def in_box(box, target):
    if (box[0] > target[0]) and (box[1] > target[1]) and (box[2] < target[2]) and (box[3] < target[3]):
        return True
    else:
        return False

# 根据位置裁剪图像并保存的函数
def crop_for_clip(image, box, i, position):
    image = Image.open(image)
    w, h = image.size
    # 根据位置定义裁剪区域的边界
    if position == "left":
        bound = [0, 0, w/2, h]
    elif position == "right":
        bound = [w/2, 0, w, h]
    elif position == "top":
        bound = [0, 0, w, h/2]
    elif position == "bottom":
        bound = [0, h/2, w, h]
    elif position == "top left":
        bound = [0, 0, w/2, h/2]
    elif position == "top right":
        bound = [w/2, 0, w, h/2]
    elif position == "bottom left":
        bound = [0, h/2, w/2, h]
    elif position == "bottom right":
        bound = [w/2, h/2, w, h]
    else:
        bound = [0, 0, w, h]
    
    # 如果矩形框在定义的边界内,则裁剪并保存图像
    if in_box(box, bound):
        cropped_image = image.crop(box)
        cropped_image.save(f"./temp/{i}.jpg")
        return True
    else:
        return False

# 使用CLIP模型进行图像特征提取和相似度计算的函数
def clip_for_icon(clip_model, clip_preprocess, images, prompt):
    image_features = []
    # 对每张图像进行预处理并提取特征
    for image_file in images:
        image = clip_preprocess(Image.open(image_file)).unsqueeze(0).to(next(clip_model.parameters()).device)
        image_feature = clip_model.encode_image(image)
        image_features.append(image_feature)
    image_features = torch.cat(image_features)
    
    # 对文本提示进行编码
    text = clip.tokenize([prompt]).to(next(clip_model.parameters()).device)
    text_features = clip_model.encode_text(text)

    # 归一化图像和文本特征
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    # 计算图像和文本特征之间的相似度
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=0).squeeze(0)
    _, max_pos = torch.max(similarity, dim=0)
    pos = max_pos.item()
    
    return pos
  • crop_image 函数:使用透视变换裁剪图像。
  • calculate_size 函数:计算矩形框的面积。
  • calculate_iou 函数:计算两个矩形框的交并比。
  • crop 函数:根据矩形框裁剪图像,并可选地在图像上绘制矩形框和文本。
  • in_box 函数:判断一个点是否在矩形框内。
  • crop_for_clip 函数:根据位置裁剪图像,并保存到指定路径。
  • clip_for_icon 函数:使用CLIP模型对图像进行特征提取,并计算图像与文本提示之间的相似度。

图标定位模块【处理图像识别和目标检测的结果】:

from MobileAgent.crop import calculate_size, calculate_iou
from PIL import Image
import torch

# 移除重叠的矩形框的函数
def remove_boxes(boxes_filt, size, iou_threshold=0.5):
    boxes_to_remove = set()

    # 遍历所有矩形框,移除面积过小的框
    for i in range(len(boxes_filt)):
        if calculate_size(boxes_filt[i]) > 0.05*size[0]*size[1]:
            boxes_to_remove.add(i)
    # 遍历所有矩形框,移除与其他框重叠过多的框
    for j in range(len(boxes_filt)):
        if calculate_size(boxes_filt[j]) > 0.05*size[0]*size[1]:
            boxes_to_remove.add(j)
        if i == j:
            continue
        if i in boxes_to_remove or j in boxes_to_remove:
            continue
        iou = calculate_iou(boxes_filt[i], boxes_filt[j])
        if iou >= iou_threshold:
            boxes_to_remove.add(j)

    # 过滤掉需要移除的矩形框
    boxes_filt = [box for idx, box in enumerate(boxes_filt) if idx not in boxes_to_remove]
    
    return boxes_filt

# 目标检测的函数
def det(input_image_path, caption, groundingdino_model, box_threshold=0.05, text_threshold=0.5):
    image = Image.open(input_image_path)
    size = image.size

    # 处理文本提示,确保以句号结尾
    caption = caption.lower()
    caption = caption.strip()
    if not caption.endswith('.'):
        caption = caption + '.'
    
    # 准备输入数据
    inputs = {
        'IMAGE_PATH': input_image_path,
        'TEXT_PROMPT': caption,
        'BOX_TRESHOLD': box_threshold,
        'TEXT_TRESHOLD': text_threshold
    }

    # 调用目标检测模型
    result = groundingdino_model(inputs)
    boxes_filt = result['boxes']

    # 将检测到的矩形框转换为图像坐标
    H, W = size[1], size[0]
    for i in range(boxes_filt.size(0)):
        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]

    # 将矩形框坐标转换为列表形式
    boxes_filt = boxes_filt.cpu().int().tolist()
    # 移除重叠的矩形框
    filtered_boxes = remove_boxes(boxes_filt, size)  # [:9]
    coordinates = []
    # 将过滤后的矩形框坐标添加到列表中
    for box in filtered_boxes:
        coordinates.append([box[0], box[1], box[2], box[3]])

    return coordinates
  • remove_boxes 函数:移除检测到的矩形框中面积过小或与其他框重叠过多的框。
  • det 函数:进行目标检测,处理输入图像和文本提示,调用目标检测模型,然后将检测到的矩形框转换为图像坐标,最后移除重叠的矩形框并返回最终的坐标列表。

提示模块:

def get_action_prompt(instruction, clickable_infos, width, height, keyboard, summary_history, action_history, last_summary, last_action, add_info, error_flag, completed_content, memory):
    prompt = "### Background ###\n"
    prompt += f"This image is a phone screenshot. Its width is {width} pixels and its height is {height} pixels. The user\'s instruction is: {instruction}.\n\n"
    
    prompt += "### Screenshot information ###\n"
    prompt += "In order to help you better perceive the content in this screenshot, we extract some information on the current screenshot through system files. "
    prompt += "This information consists of two parts: coordinates; content. "
    prompt += "The format of the coordinates is [x, y], x is the pixel from left to right and y is the pixel from top to bottom; the content is a text or an icon description respectively. "
    prompt += "The information is as follow:\n"

    for clickable_info in clickable_infos:
        if clickable_info['text'] != "" and clickable_info['text'] != "icon: None" and clickable_info['coordinates'] != (0, 0):
            prompt += f"{clickable_info['coordinates']}; {clickable_info['text']}\n"
    
    prompt += "Please note that this information is not necessarily accurate. You need to combine the screenshot to understand."
    prompt += "\n\n"
    
    prompt += "### Keyboard status ###\n"
    prompt += "We extract the keyboard status of the current screenshot and it is whether the keyboard of the current screenshot is activated.\n"
    prompt += "The keyboard status is as follow:\n"
    if keyboard:
        prompt += "The keyboard has been activated and you can type."
    else:
        prompt += "The keyboard has not been activated and you can\'t type."
    prompt += "\n\n"
    
    if add_info != "":
        prompt += "### Hint ###\n"
        prompt += "There are hints to help you complete the user\'s instructions. The hints are as follow:\n"
        prompt += add_info
        prompt += "\n\n"
    
    if len(action_history) > 0:
        prompt += "### History operations ###\n"
        prompt += "Before reaching this page, some operations have been completed. You need to refer to the completed operations to decide the next operation. These operations are as follow:\n"
        for i in range(len(action_history)):
            prompt += f"Step-{i+1}: [Operation: " + summary_history[i].split(" to ")[0].strip() + "; Action: " + action_history[i] + "]\n"
        prompt += "\n"
    
    if completed_content != "":
        prompt += "### Progress ###\n"
        prompt += "After completing the history operations, you have the following thoughts about the progress of user\'s instruction completion:\n"
        prompt += "Completed contents:\n" + completed_content + "\n\n"
    
    if memory != "":
        prompt += "### Memory ###\n"
        prompt += "During the operations, you record the following contents on the screenshot for use in subsequent operations:\n"
        prompt += "Memory:\n" + memory + "\n"
    
    if error_flag:
        prompt += "### Last operation ###\n"
        prompt += f"You previously wanted to perform the operation \"{last_summary}\" on this page and executed the Action \"{last_action}\". But you find that this operation does not meet your expectation. You need to reflect and revise your operation this time."
        prompt += "\n\n"
    
    prompt += "### Response requirements ###\n"
    prompt += "Now you need to combine all of the above to perform just one action on the current page. You must choose one of the six actions below:\n"
    prompt += "Open app (app name): If the current page is desktop, you can use this action to open the app named \"app name\" on the desktop.\n"
    prompt += "Tap (x, y): Tap the position (x, y) in current page.\n"
    prompt += "Swipe (x1, y1), (x2, y2): Swipe from position (x1, y1) to position (x2, y2).\n"
    if keyboard:
        prompt += "Type (text): Type the \"text\" in the input box.\n"
    else:
        prompt += "Unable to Type. You cannot use the action \"Type\" because the keyboard has not been activated. If you want to type, please first activate the keyboard by tapping on the input box on the screen.\n"
    prompt += "Home: Return to home page.\n"
    prompt += "Stop: If you think all the requirements of user\'s instruction have been completed and no further operation is required, you can choose this action to terminate the operation process."
    prompt += "\n\n"
    
    prompt += "### Output format ###\n"
    prompt += "Your output consists of the following three parts:\n"
    prompt += "### Thought ###\nThink about the requirements that have been completed in previous operations and the requirements that need to be completed in the next one operation.\n"
    prompt += "### Action ###\nYou can only choose one from the six actions above. Make sure that the coordinates or text in the \"()\".\n"
    prompt += "### Operation ###\nPlease generate a brief natural language description for the operation in Action based on your Thought."
    
    return prompt


def get_reflect_prompt(instruction, clickable_infos1, clickable_infos2, width, height, keyboard1, keyboard2, summary, action, add_info):
    prompt = f"These images are two phone screenshots before and after an operation. Their widths are {width} pixels and their heights are {height} pixels.\n\n"
    
    prompt += "In order to help you better perceive the content in this screenshot, we extract some information on the current screenshot through system files. "
    prompt += "The information consists of two parts, consisting of format: coordinates; content. "
    prompt += "The format of the coordinates is [x, y], x is the pixel from left to right and y is the pixel from top to bottom; the content is a text or an icon description respectively "
    prompt += "The keyboard status is whether the keyboard of the current page is activated."
    prompt += "\n\n"
    
    prompt += "### Before the current operation ###\n"
    prompt += "Screenshot information:\n"
    for clickable_info in clickable_infos1:
        if clickable_info['text'] != "" and clickable_info['text'] != "icon: None" and clickable_info['coordinates'] != (0, 0):
            prompt += f"{clickable_info['coordinates']}; {clickable_info['text']}\n"
    prompt += "Keyboard status:\n"
    if keyboard1:
        prompt += f"The keyboard has been activated."
    else:
        prompt += "The keyboard has not been activated."
    prompt += "\n\n"
            
    prompt += "### After the current operation ###\n"
    prompt += "Screenshot information:\n"
    for clickable_info in clickable_infos2:
        if clickable_info['text'] != "" and clickable_info['text'] != "icon: None" and clickable_info['coordinates'] != (0, 0):
            prompt += f"{clickable_info['coordinates']}; {clickable_info['text']}\n"
    prompt += "Keyboard status:\n"
    if keyboard2:
        prompt += f"The keyboard has been activated."
    else:
        prompt += "The keyboard has not been activated."
    prompt += "\n\n"
    
    prompt += "### Current operation ###\n"
    prompt += f"The user\'s instruction is: {instruction}. You also need to note the following requirements: {add_info}. In the process of completing the requirements of instruction, an operation is performed on the phone. Below are the details of this operation:\n"
    prompt += "Operation thought: " + summary.split(" to ")[0].strip() + "\n"
    prompt += "Operation action: " + action
    prompt += "\n\n"
    
    prompt += "### Response requirements ###\n"
    prompt += "Now you need to output the following content based on the screenshots before and after the current operation:\n"
    prompt += "Whether the result of the \"Operation action\" meets your expectation of \"Operation thought\"?\n"
    prompt += "A: The result of the \"Operation action\" meets my expectation of \"Operation thought\".\n"
    prompt += "B: The \"Operation action\" results in a wrong page and I need to return to the previous page.\n"
    prompt += "C: The \"Operation action\" produces no changes."
    prompt += "\n\n"
    
    prompt += "### Output format ###\n"
    prompt += "Your output format is:\n"
    prompt += "### Thought ###\nYour thought about the question\n"
    prompt += "### Answer ###\nA or B or C"
    
    return prompt


def get_memory_prompt(insight):
    if insight != "":
        prompt  = "### Important content ###\n"
        prompt += insight
        prompt += "\n\n"
    
        prompt += "### Response requirements ###\n"
        prompt += "Please think about whether there is any content closely related to ### Important content ### on the current page? If there is, please output the content. If not, please output \"None\".\n\n"
    
    else:
        prompt  = "### Response requirements ###\n"
        prompt += "Please think about whether there is any content closely related to user\'s instrcution on the current page? If there is, please output the content. If not, please output \"None\".\n\n"
    
    prompt += "### Output format ###\n"
    prompt += "Your output format is:\n"
    prompt += "### Important content ###\nThe content or None. Please do not repeatedly output the information in ### Memory ###."
    
    return prompt

def get_process_prompt(instruction, thought_history, summary_history, action_history, completed_content, add_info):
    prompt = "### Background ###\n"
    prompt += f"There is an user\'s instruction which is: {instruction}. You are a mobile phone operating assistant and are operating the user\'s mobile phone.\n\n"
    
    if add_info != "":
        prompt += "### Hint ###\n"
        prompt += "There are hints to help you complete the user\'s instructions. The hints are as follow:\n"
        prompt += add_info
        prompt += "\n\n"
    
    if len(thought_history) > 1:
        prompt += "### History operations ###\n"
        prompt += "To complete the requirements of user\'s instruction, you have performed a series of operations. These operations are as follow:\n"
        for i in range(len(summary_history)):
            operation = summary_history[i].split(" to ")[0].strip()
            prompt += f"Step-{i+1}: [Operation thought: " + operation + "; Operation action: " + action_history[i] + "]\n"
        prompt += "\n"
        
        prompt += "### Progress thinking ###\n"
        prompt += "After completing the history operations, you have the following thoughts about the progress of user\'s instruction completion:\n"
        prompt += "Completed contents:\n" + completed_content + "\n\n"
        
        prompt += "### Response requirements ###\n"
        prompt += "Now you need to update the \"Completed contents\". Completed contents is a general summary of the current contents that have been completed based on the ### History operations ###.\n\n"
        
        prompt += "### Output format ###\n"
        prompt += "Your output format is:\n"
        prompt += "### Completed contents ###\nUpdated Completed contents. Don\'t output the purpose of any operation. Just summarize the contents that have been actually completed in the ### History operations ###."
        
    else:
        prompt += "### Current operation ###\n"
        prompt += "To complete the requirements of user\'s instruction, you have performed an operation. Your operation thought and action of this operation are as follows:\n"
        prompt += f"Operation thought: {thought_history[-1]}\n"
        operation = summary_history[-1].split(" to ")[0].strip()
        prompt += f"Operation action: {operation}\n\n"
        
        prompt += "### Response requirements ###\n"
        prompt += "Now you need to combine all of the above to generate the \"Completed contents\".\n"
        prompt += "Completed contents is a general summary of the current contents that have been completed. You need to first focus on the requirements of user\'s instruction, and then summarize the contents that have been completed.\n\n"
        
        prompt += "### Output format ###\n"
        prompt += "Your output format is:\n"
        prompt += "### Completed contents ###\nGenerated Completed contents. Don\'t output the purpose of any operation. Just summarize the contents that have been actually completed in the ### Current operation ###.\n"
        prompt += "(Please use English to output)"
        
    return prompt
  • get_action_prompt:生成一个行动提示,用于指导AI助手根据用户指令、截图信息、键盘状态、历史操作、额外信息、记忆内容和错误标志来执行特定的操作。
  • get_reflect_prompt:生成一个反思提示,用于指导AI助手根据操作前后的截图信息、用户指令和额外信息来评估操作是否符合预期。
  • get_memory_prompt:生成一个记忆提示,用于指导AI助手根据洞见信息来回忆和输出相关内容。
  • get_process_prompt:生成一个处理提示,用于指导AI助手根据用户指令、历史思考和额外信息来更新已完成内容的总结。

文本定位模块(处理图像中的坐标点排序、计算字符串的最长公共子串长度,以及使用OCR技术提取图像中的文本信息):

import cv2
import numpy as np
from MobileAgent.crop import crop_image

# 对四边形的四个顶点进行排序的函数
def order_point(coor):
    # 将输入的坐标转换为4x2的numpy数组
    arr = np.array(coor).reshape([4, 2])
    # 计算所有顶点的中心点
    sum_ = np.sum(arr, 0)
    centroid = sum_ / arr.shape[0]
    # 计算每个顶点相对于中心点的角度
    theta = np.arctan2(arr[:, 1] - centroid[1], arr[:, 0] - centroid[0])
    # 根据角度对顶点进行排序
    sort_points = arr[np.argsort(theta)]
    # 调整排序后的顶点,确保第一个顶点在中心点的左侧
    sort_points = sort_points.reshape([4, -1])
    if sort_points[0][0] > centroid[0]:
        sort_points = np.concatenate([sort_points[3:], sort_points[:3]])
    sort_points = sort_points.reshape([4, 2]).astype('float32')
    return sort_points

# 计算两个字符串最长公共子串长度的函数
def longest_common_substring_length(str1, str2):
    m = len(str1)
    n = len(str2)
    # 初始化动态规划数组
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    # 填充动态规划数组
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if str1[i - 1] == str2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])
    # 返回最长公共子串的长度
    return dp[m][n]

# 使用OCR技术提取图像中文本信息的函数
def ocr(image_path, ocr_detection, ocr_recognition):
    text_data = []
    coordinate = []
    
    # 读取完整的图像
    image_full = cv2.imread(image_path)
    # 进行文本检测
    det_result = ocr_detection(image_full)
    det_result = det_result['polygons'] 
    for i in range(det_result.shape[0]):
        # 对检测到的多边形顶点进行排序
        pts = order_point(det_result[i])
        # 根据排序后的顶点坐标裁剪图像
        image_crop = crop_image(image_full, pts)
        try:
            # 进行文本识别
            result = ocr_recognition(image_crop)['text'][0]
        except:
            continue

        # 将裁剪框的坐标转换为整数并存储
        box = [int(e) for e in list(pts.reshape(-1))]
        box = [box[0], box[1], box[4], box[5]]
        
        text_data.append(result)
        coordinate.append(box)
        
    else:
        return text_data, coordinate
  • order_point 函数:对四边形的四个顶点进行排序,以便正确地裁剪图像。
  • longest_common_substring_length 函数:使用动态规划算法计算两个字符串的最长公共子串长度。
  • ocr 函数:使用OCR技术提取图像中的文本信息。它首先使用ocr_detection函数检测图像中的文本区域,然后对每个检测到的区域进行排序和裁剪,最后使用ocr_recognition函数识别裁剪后的图像中的文本。

竞赛入门(第四周)

本期字数限制,下期更新!

百日文“新”(第五周)

天下苦Transformer久矣(Week1)

Transformer

  • 创新点
    • 核心思想:通过自注意力机制处理序列数据。
    • 大白话解释:用大剪子把序列数据按窗口进行截断,然后通过位置嵌入编码和自注意力计算简化了时序关系。【裁缝思想】
  • 优点
    • 减少了时间维度
    • 暴力拆解具有通用性,可以通过堆叠模块处理大数据。
  • 缺点
    • 没有由于升维,空间复杂度增大。
      在这里插入图片描述
      两个相乘的矩阵大小分别为(N × \times × d) 和(d × \times × N),我们需要拿第一个矩阵的每一行去与第二个矩阵的每一列做点乘,所以总共就需要 N2 次点乘。而每次点乘又需要 d 次乘法,所以总复杂度就为O(N2d)。

改进思路:

  • 魔改注意力机制:S4、FlashAttention
  • 提出了新的序列模型:Mamba

Mamba1

单层网络

在这里插入图片描述
输入是x,经过变换Wx+b和激活函数f,得到输出y。

RNN

因为序列数据就不太好用原始的神经网络处理了,所以RNN引入了隐状态h(hidden state)的概念,隐状态h可以对序列形的数据提取特征,接着再转换为输出。

在这里插入图片描述

  • h 1 h_1 h1基于上一个隐藏层的状态 h 0 h_{0} h0和当前的输入 x 1 x_{1} x1计算得来
  • 参数共享:RNN的权值是在同一个向量中,只是不同时刻而已。所以每一步使用的参数U、W、b都是一样的,也就是说每个步骤的参数都是共享的。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

  • 细品:先根据输入 x t x_t xt和前一时刻的隐藏状态 h t − 1 h_{t-1} ht1计算出最新的隐藏状态 h t h_t ht,便可以根据最新的隐藏状态 h t h_t ht预测出 y t y_t yt
  • 存在问题:
    • 远距离的梯度消失,不适合长距离依赖;
    • 无法并行训练,导致训练成本较高;

SSM(时序状态空间模型 )

🔗推荐教程:https://www.youtube.com/watch?v=luCBXCErkCs

连续时间表示( continuous-time representation )
在这里插入图片描述

  • A就是存储着之前所有历史信息的浓缩精华(可以通过一系列系数组成的矩阵表示之),以基于A更新下一个时刻的空间状态hidden state
  • 与RNN循环结构: h t = t a n h ( W h t − 1 + U x t ) h_{t}=tanh \left(W h_{t-1}+U x_{t}\right) ht=tanh(Wht1+Uxt)非常类似:
  • A、B、C、D这4个矩阵是参数,是可以学习到的

统一方程后:
在这里插入图片描述

  • 1.假设我们有一些输入信号x(t),该信号首先乘以矩阵 B在这里插入图片描述
  • 2.上面第一步的结果,加上:上一个状态与矩阵A相乘(矩阵A描述了所有内部状态如何连接)的结果,用来更新状态state在这里插入图片描述
  • 3.然后,使用矩阵C来将状态转换为输出在这里插入图片描述
  • 4.最后,再利用矩阵D提供从输入到输出的直接信号,这通常也称为跳跃连接skip-connection在这里插入图片描述
  • 5.由于矩阵D类似于跳跃连接,因此在没有跳跃连接的情况下,SSM 通常被视为如下在这里插入图片描述
SSM升级到S4
  • 离散化:
    • 原因:文本序列是离散输入
    • 目的:将函数到函数 x ( t ) → y ( t ) x(t) \rightarrow y(t) x(t)y(t),而是序列到序列 x k → y k x_{k} \rightarrow y_{k} xkyk
    • 处理方式:利用零阶保持技术(Zero-order hold technique)【在对应变量上面加一个横杠】在这里插入图片描述
      • 1.保留每次收到离散信号,直到收到新的离散信号
      • 2.保持该值的时间由一个新的可学习参数表示,称为步长(size)—— Δ \Delta Δ ,它代表输入的阶段性保持(resolution)在这里插入图片描述
      • 3.有了连续的输入信号后,便可以生成连续的输出,并且仅根据输入的时间步长对值进行采样在这里插入图片描述
  • 循环结构表示
    • 目的:方便快速推理
    • 推导:
      • 展开每个时间步在这里插入图片描述
      • 对于y2,具体展开在这里插入图片描述
      • 推广:在这里插入图片描述
      • 由此,我们可以采用RNN的结构进行处理, h k h_k hk始终是 B ‾ x k \overline{\mathbf{B}} \mathbf{x}_{\mathrm{k}} Bxk A ‾ h k − 1 \overline{\mathbf{A}} \mathbf{h}_{\mathrm{k}-1} Ahk1的共同作用之下更新的,在这里插入图片描述
  • 卷积结构表示
    • 目的:方便并行训练
    • 卷积核(过滤器,kernels):处理的是文本而不是图像,因此我们需要一维视角在这里插入图片描述
      • 过滤器( K ‾ \overline{K} K)的理解:在这里插入图片描述
        • 1.使用 SSM 内核来检查每组token并计算输出在这里插入图片描述
        • 2.内核移动,执行下一步的计算在这里插入图片描述
        • 3.同理,继续移动在这里插入图片描述
        • 4.转化为点积的形式:
          y 2 = [ C A ˉ 2 B ˉ C A ˉ B ˉ C B ˉ ] ⋅ [ x 0 x 1 x 2 ] = K ˉ ⋅ X y_2=\left[ \begin{matrix} C\bar{A}^2\bar{B}& C\bar{A}\bar{B}& C\bar{B}\\ \end{matrix} \right] \cdot \left[ \begin{array}{c} x_0\\ x_1\\ x_2\\ \end{array} \right] =\bar{K}\cdot X y2=[CAˉ2BˉCAˉBˉCBˉ] x0x1x2 =KˉX
        • 5.由于其中三个离散参数A、B、C都是常数,因此我们可以预先计算左侧向量并将其保存为卷积核在这里插入图片描述

基于卷积与循环的利弊,我们采用两全其美的办法:

在这里插入图片描述
推理用RNN结构,训练用CNN结构在这里插入图片描述
💡这里我们用生活的例子,帮你更好理解RNN和SSM的区别,想象你在读一本书

  • 时序嵌套的 RNN每次只能读一行,然后把记忆传递到下一行,这种方法只适合处理短故事,故事一长,容易忘记前面的情节。
  • 而 SSM 并行处理,同时打开所有页看到每行内容,这样就能快速找到和理解整本书,无需逐行传递记忆
提出Mamaba

在这里插入图片描述
Mamaba这个名字来源于黑曼巴蛇(Black Mamba),以速度和致命著称。这种命名意在传达该架构的速度、灵活性和高效性,反映出它在处理和转换数据方面的强大能力。

  • 论文名称:Mamba: Linear-Time Sequence Modeling with Selective State Spaces,202405
  • 论文地址:https://arxiv.org/abs/2312.00752
  • 论文代码:https://github.com/state-spaces/mamba
  • 论文作者:Albert Gu, Tri Dao
  • 研究设计和结论
    • 问题:
      • Transformer的注意力机制虽然有效,但效率很低,因为它需要存储整个上下文(storing the entire context,也就是KV缓存),导致训练和推理消耗算力大。
      • RNN的推理和训练效率高,但性能容易受到对上下文压缩程度的限制。【写的快,忘的快】
      • SSM的问题在于其中的矩阵A B C不随输入不同而不同,即无法针对不同的输入针对性的推理
    • 方法:平衡一下,既要有抓重点的能力(选择性复制任务),也要有上下文联想/推理能力(诱导头任务)。在这里插入图片描述
      • 选择性SSM:保持时间维度,通过不同的阀门(矩阵 ( B t , B_t, Bt, C t C_t Ct) 和函数 ( Δ t (\Delta_t (Δt))调节记忆融合和特征捕捉。【大白话就是一个流淌的大水管子,通过设计不同的阀门来调节记忆的融合和特征的捕捉【水暖工旋钮思想在这里插入图片描述
      • 算法改进:左边是原来的算法,可以看出ABC和 Δ \Delta Δ都是固定的。现在分别用三个S 函数根据输入把 B、C、 Δ \Delta Δ都变成了时变的。在这里插入图片描述在这里插入图片描述
        • B:批次大小(Batch size)。表示一次输入的数据量的大小。类似于RNN中的输入门。
        • L:序列长度 (Sequence length)。表示每个序列中包含的时间步数。
        • N:特征维度(Feature dimension)。表示每个时间步的特征数量。
        • D:输入特征维度(nput feature dimension。
        • C:类似于 RNN 中的输出门。
        • A:对应这个维度的SSM来说,A在每个hidden state维度上的作用可以不相同,起到multi-scale/fine-grained gating的作用,这也是LSTM网络里面用element-wise product的原因。在这里插入图片描述
        • Δ \Delta Δ:就好比放大镜观察窗口,影响信息处理的焦点。较小的步长Δ会忽略当前输入,而更多地使用先前的上文,而较大的步长Δ会更多地关注当前输入而不是上文。在这里插入图片描述
        • 维度变化具体执行过程:
          • s B ( x ) = L i n e a r N ( x ) 、 s C ( x ) = L i n e a r N ( x ) s_B(x)=Linear_N(x)、s_C(x)=Linear_N(x) sB(x)=LinearN(x)sC(x)=LinearN(x)都是线性投影,这是种常见的神经网络操作,用于将输入数据转换到一个新的空间或维度。这里的 Linear 表示是用线性层来学习这几个函数。
          • s Δ ( x ) = B r o a d c a s t D ( L i n e a r 1 ( x ) ) s_\Delta(x) = Broadcast_D(Linear_1(x)) sΔ(x)=BroadcastD(Linear1(x)),广播是一个数操作,它使得维度较小的数组能够与维度较大的数组进行算术操作
          • τ Δ = s o f t p l u s τΔ= softplus τΔ=softplus,这是个平滑的非线性函数,通常用于网络中以添加非线性特征并帮助网络学习复杂的模式。
            S o f t p l u s ( x ) = l o g ( 1 + e x ) Softplus(x)=log(1+e^x) Softplus(x)=log(1+ex) 在这里插入图片描述
      • 流体力学与李指数映射:
        • Transformer 描述的是粒子运动,通过自注意力机制映射动态调整每个输入的权重,类似粒子间通过牛顿力学相互作用力来动态调整自己的轨迹。训练的过程,就是在用牛顿力学拟合粒子轨迹,每个输入 (粒子)独立计算与其他输入的关系。
        • Mamba 描述的是流体运动,通过李指数映射来建模时空结构。
          • 流体运动:流体运动描述的是连续介质中的分子集体行为,运动是整体的,内部各点之间有强烈的相互关系和依赖。流体的每个部分都受到整体流体运动的影响,通过内部压力、粘性等因素相互作用。
          • 记忆:记忆系统具有连续性、动态变化性和整体关联性,这些特性与流体的性质非常相似。流体模型能够更好地描述记忆中的信息如何相互关联、如何随着时间和新信息的出现进行动态调整和整合。
          • 李指数映射(Lieexponential map):李指数映射是一种数学工具,用于描述和分析一个向量场如何沿着另一个向量场发生变化,比如流体力学、电磁场、广义相对论的时空结构等解决了动态系统中相互作用的描述。它是群论和微分几何中重要的概念,来源于李群和李代数的理论,是挪威数学家索菲斯李引入的。训练 mamba 的过程就是用李指数映射拟合流体力学动态系统,找到主管道 A,调整阀门和旋钮 B t 、 C t 、 Δ t B_t、C_t、\Delta_t BtCtΔt,获得最优流体流动路径,让模型能在高维特征空间中进行高效导航和决策。
            • 把记忆的流淌比作一个水流管道系统,可以看做一个“李群”,进行各种复杂变换(比如旋转、推移等)。
            • 固定矩阵A就是主管道(全局演变路径),类似于流体运动的全局关系,让系统状态更新有固定的全局路径和规则。
            • B t 、 C t B_t、C_t BtCt就是阀门或旋钮, Δ \Delta Δ 这个离散化因子,就像是流体力学中的时间步长,决定流体运动的离散时间点。
      • 内核融合:把离散化和循环在GPU SRAM 内存中实现,然后加载和存储参数ABC矩阵都用HBM高带宽内存。【Flash Attention技术在这里插入图片描述
      • 并行扫描:
        • 原因:由于A、B、C这些矩阵现在是动态的了,因此无法使用卷积表示来计算它们(CNN需要固定的内核),因此,我们只能使用循环表示,如此也就而失去了卷积提供的并行训练能力
        • 扫描操作(scan operation):每个状态比如 H 1 H_1 H1都是前一个状态比如 H 0 H_0 H0乘以 A ˉ \bar{A} Aˉ,加上当前输入 X 1 X_1 X1乘以 B ˉ \bar{B} Bˉ的总和。这种状态之下想并行化是不可能的(因为只有在获取到前一个状态的情况下才能计算当前的每个状态)在这里插入图片描述
        • 大白话:每个状态H相当于一个人,看它对应的这条处理线程。自己忙活自己的,吃着碗里看着锅里,当别人的饭做好后,就抢过来用一下,所以某种程度上实现了并行计算。在这里插入图片描述
      • 重计算:避免存储反向传播所需的中间状态,在输入从HBM加载到SRAM时在反向通道中重计算。
      • 网络结构:在这里插入图片描述
        • 经过线性投影后,输入嵌入的维度可能会增加,以便让模型能够处理更高维度的特征空间,从而捕获更细致、更复杂的特征。
        • SSM之前的CNN负责提取局部特征(因其擅长捕捉局部的短距离特征),而SSM则负责处理这些特征并捕捉序列数据中的长期依赖关系,两者算互为补充
        • CNN有助于建立token之间的局部上下文关系,从而防止独立的token计算。如果每个 token 独立计算,那么模型就会丢失序列中 token 之间的上下文信息。若通过先进行卷积操作,可以确保在进入 SSM 之前,序列中的每个 token 已经考虑了其邻居 token 的信息。这样,模型就不会单独地处理每个 token,而是在处理时考虑了整个局部上下文。
    • 实验结果:
      • 任务对比:coping、selective copying、induction heads在这里插入图片描述
        • (左)复制任务:LTI的效果,输出只能对规则的输入特征进行发现,
        • (右上)选择性复制任务:能自己找重点了,带色的尽管开始间隔大小不一,但都能找出来排好队。
        • (右下)归纳头部任务:联想能力的体现再看到黑的后就想到以前后面应该跟着蓝色的。也就是说,对于非线性时变数据,具备了很强的特征捕捉能力。
      • 合成任务验证选择机制:
        在这里插入图片描述
        • 选择性复制任务要求模型能够记住并复制序列中的特定单词。表一Mamba架构与选择性机制结合后的表现优秀,(S6):准确率为99.8。
        • 扩展序列长度任务要求模型看到一个二元组(如“Harry Potter)时,能够记住“Harry并在序列中再次出现时预测“Potter”。表二Mamba架构,也就是最上面的棕色线,比其他方法要好两倍。
      • 语言模型预训练:在这里插入图片描述
        • 左图是较短的序列长度2048,右图是较长的序列长度8192。两图对比了不同模型,横轴为 FLOPs 计算复杂度由低到高,纵轴为困感度。可以看出Mamba 模型困感度最低。
        • 结论:它是第一个无需注意力机制就能在扩展定律scaling Laws 上匹敌强大Transformer++模型的架构。
      • DNA序列:由于大型语言模型的成功,人们开始探索将基础模型范式应用于基因组学。DNA被视为一种由有限词汇组成的离散序列,需要模型处理长程依赖。实验和图表展示了Mamba 架构在 DNA建模任务中的卓越性能,特别是在处理长序列和扩展模型大小方面。
        • 横轴为参数量,纵轴为困惑度。
        • Mamba (橙色线)的困感度:左图随着参数数量的增加(从约200K到约40M)显著下降,右图随着序列长度保持稳定。在这里插入图片描述
        • 下面是物种 DNA分类任务的微调准确率和不同数据集上的扩展定律。随序列长度,Mamba性能更好。在这里插入图片描述
      • 音频例子:音频波形建模和生成问题也是序列建模任务。在这里插入图片描述
      • 训练推理效率分析:在这里插入图片描述
        • 训练效率:Mamba的扫描实现比标准实现快40倍,处理长序列时时间增长最慢(橙线)。对比蓝线,最右边512 时,1ms /0.025ms = 40倍。
        • 推理效率:Mamba在推理阶段的吞吐量比Transformers高5倍,特别是在大批次处理时,显著优于其他模型。蓝色条的高度明显高于其他模型。
      • 消融实验: Δ t \Delta_t Δt是最重要的参数,其次是B和C的组合使用,这个好理解,选择性主要就靠它来确定几个函数 S B , S C , S Δ S_B, S_C, S_\Delta SB,SC,SΔ。随机初始化表现较好,复杂初始化效果较差。增大N(SSM状态维度)显著改善性能,成本增加微乎其微,但只有在B和C也选择性时才有效。这些发现验证了选择性SSM在语言建模任务中的有效性和优势。在这里插入图片描述
  • 论文贡献
    • 创新点:
      • 对输入信息有选择性处理(Selection Mechanism)
      • 硬件感知的算法(Hardware-aware Algorithm):该算法采用“并行扫描算法”而非“卷积”来进行模型的循环计算(使得不用CNN也能并行训练),但为了减少GPU内存层次结构中不同级别之间的IO访问,它没有具体化扩展状态
        当然,这点也是受到了S5(Simplified State Space Layers for Sequence Modeling)的启发
      • 更简单的架构:将SSM架构的设计与transformer的MLP块合并为一个块(combining the design of prior SSM architectures with the MLP block of Transformers into a single block),来简化过去的深度序列模型架构,从而得到一个包含selective state space的架构设计
    • 不足:模型堆叠能力差点意思

MambaOut

自注意力机制的类别

在这里插入图片描述

  • 因果模式:只能看过去,不能看未来,只有记忆没有未卜先知,适合自回归用来生成和预测,以史为鉴。如GPT。
  • 全可见模式:过去未来都可见,适合理解,左顾右看瞻前顾后。如Bert。

Mamba的选择性注意力机制属于因果模式,但和Transformer的因果注意力有区别。

在这里插入图片描述

  • Transformer的因果注意力(左图)是组合(叠加)之前所有的记忆,记忆无损但复杂度增加,越累越长,计算复杂度同样为0(L2);
  • Mamba的因果注意力(右图)是融合之前的记忆到新的隐藏状态,记忆有损但复杂度恒定。
Mamba适用的特征任务
  • 任务涉及处理长序列,因为复杂度低,更高效。
  • 任务需要因果token 混合模式。

💡此时我们反向思考,Mamba不适合短序列任务,且不需要考虑因果。

那么什么样的任务适合满足这种类型呢?视觉识别任务中的图像分类

视觉识别任务:

  • 图像分类:主要关注整体特征空间特征就够了,目标也只是粗犷的类别标号,因此不涉及什么序列信息,而且需要全局信息。
  • 目标检测和语义分割:不一定,比如要考虑边缘的连贯性,因此可能有序列问题。

因果模式没必要:
在这里插入图片描述
右图显示以ViT 为例,将自注意力机制从全可见模式切换到因果模式后,性能有所下降,说明对于图像分类问题,用因果模式没必要。


图像处理任务不属于长序列:

Transformer的浮点运算次数公式: F L O P s = 24 D 2 L + 4 D L 2 FLOPs =24D^2L +4DL^2 FLOPs=24D2L+4DL2

  • L是token 长度 (即输入序列的长度)
  • D是通道维度(即特征维度)
  • 24 D 2 L 24D^2L 24D2L代表线性复杂度
  • 4 D L 2 4DL^2 4DL2代表二次复杂度

如何判断它是不是需要长序列建模,只需要看计算量是不是对L敏感

Transformer的浮点运算次数公式推导:

在这里插入图片描述

自注意力部分:【X矩阵:L * D,权重矩阵W:D * D】

  • 查询矩阵: Q = X W q Q=XW_q Q=XWq
  • 键矩阵: K = X W k K=XW_k K=XWk
  • 值矩阵: V = X W v V=XW_v V=XWv

总的计算复杂度为: 3 L D 2 3LD^2 3LD2

注意力分数计算公式:【Q矩阵:L * D, K T K^T KT矩阵:D * L,V矩阵:L * D】
在这里插入图片描述
总的计算复杂度为: L 2 D L^2D L2D + L 2 D L^2D L2D = 2 L 2 D 2L^2D 2L2D

前馈神经网络通常包括两个线性层和一个激活函数计算复杂度为: 4 L D 2 + 4 L D 2 = 8 L D 2 4LD^2+4LD^2=8LD^2 4LD2+4LD2=8LD2

再考虑到多头注意力机制、残差连接、归一化以及前馈神经网络的多次计算等等,最后得到: F L O P s = 24 D 2 L + 4 D L 2 FLOPs =24D^2L +4DL^2 FLOPs=24D2L+4DL2

对于FLOPs(L),二次项与线性项的比率为: r L = 24 D 2 L 4 D L 2 = 6 D L r_L=\frac{24D^2L}{4DL^2}=\frac{6D}{L} rL=4DL224D2L=L6D,大于1说明二次项的计算复杂度超过了线性项,因此任务涉及长序列建模。


图像分类任务推导:

以VIT模型为例:

在这里插入图片描述
我们假设输入的图像大小为 224 * 224,patch大小为16 * 16

则 L(token数量)= ( 224 16 ) 2 (\frac{224}{16})^2 (16224)2 = 14 * 14 = 196。

每个这样的 patch 通道被展平成一个长向量,RGB三通道就是16163=768,然后通过一个线性投影层(粉色的)映射到高维空间,也就是通道维度 D,它是个指定的超参数。对于 ViT-S,常见的通道维度 D 为384。对图像分类任务L=196,远小于6D=6*384=2304,因此不涉及长序列建模。


目标检测和实例分割问题:

在 COCO 数据集上推理图像大小为 800x1280,生成的 token 数约为 4000,大于6D=6*384=2304,因此涉及长序列建模。

MambaOut

在这里插入图片描述

“What can I say, Mamba out.” — Kobe Bryant, NBA farewell speech, 2016

  • 论文名称:MambaOut: Do We Really Need Mamba for Vision?,20240520
  • 论文地址:https://arxiv.org/abs/2405.07992
  • 论文代码:
  • 论文作者:Weihao Yu, Xinchao Wang
  • 研究设计和结论
    • 问题: 解决注意力机制的二次复杂性问题。
    • 方法:通过堆叠 Mamba 块,同时移除其核心标记混合器 SSM,构建了一系列名为 MambaOut 的模型。在这里插入图片描述
      • 左图为MambaOut的架构:
        • 输入图像大小为 H × W × 3 {H}\times{W}\times{3} H×W×3,表示图像的高度、宽度和 RGB 三个颜色通道。
        • 采用了分层架构,共有四个阶段,每个阶段进行特征提取和降采样。
        • 每个阶段包含若干个 GatedCNN 块,用于特征提取。
        • 每个阶段之间有降采样操作,将特征图的大小逐渐减小,从而增加特征的抽象层次。通道维度为 D 1 , D 2 , D 3 , D 4 D_1,D_2,D_3,D_4 D1D2D3D4
      • 右图为Gated CNN 块,Gated CNN 块包含两个线性层、中间夹一个卷积层和归一化层,通过残差连接实现输入和输出的融合。和Mamba块的区别在于Gated CNN 块没有SSM(状态空间模型)在这里插入图片描述
      • MambaOut 的架构与 Swin TransformerDenseNet 在分层结构和降采样方面有相似之处,但在特征提取和信息混合机制上有所不同。
        • MambaOut 使用Gated CNN 块。
        • Swin Transformer 使用窗口注意力机制的Transformer块。在这里插入图片描述
        • DenseNet 则使用密集连接的卷积层。在这里插入图片描述
      • Gated CNN与ResNet的区别:在这里插入图片描述
        • Gated CNN 块使用了线性层进行升维操作,能够在特征空间中进行更灵活的变换,这与传统的 ResNet 中主要使用卷积操作进行特征提取有所不同。
        • Gated CNN 块中跳线增加了非线性激活函数可以被看作一种简单的门控机制,根据输入值调整输出信息量。增加了模型的非线性能力,使得模型能够学习更复杂的特征。
      • 核心代码实现:Gated CNN 块通过线性变换、卷积操作和残差连接,实现了对输入特征的扩展、局部特征提取和信息保留。结合了深度卷积网络和残差网络的优点,同时通过门控机制(如激活函数)来控制信息流。在这里插入图片描述
    • 实验结果:
      • 图像分类比较在这里插入图片描述
        • SSM 有没有对图像分类意义不大
        • 不如最新的CAFormer-M36 使用简单的可分离卷积和原始注意力机制,比所有同等大小的视觉Mamba 模型高出超过1%的准确率85.2%。
      • 目标检测和实例分割在这里插入图片描述
        • 数据集:COCO,Mask R-CNN的主干网络:MambaOut
        • Mamba适合处理长序列视觉任务:尽管MambaOut在COCO 上的目标检测和实例分割任务中可以超越一些视觉 Mamba模型,但它仍然落后于最先进的视觉 Mamba 模型,例如 VMamba 和LocalVMamba。
        • Mamba还需努力:与最先进的卷积-注意力混合模型 TransNext相比51.7%%,视觉Mamba仍表现出显著的性能差距49.2%。
      • 语义分割的比较:在这里插入图片描述
        • SSM 模块在语义分割任务上效果很好,同时也验证了MambaOut在某些情况下的有效性。视觉Mamba 需要进一步展示其在长序列建模任务中的强大性能,以在语义分割任务中实现更强的性能
  • 论文贡献
    • 创新点:
      • 定量分析论证了图像分类任务不是长序列建模问题,而目标检测和实例分割是。
      • 借鉴Mamba的GatedCNN 结构微调了 ResNet,实现了一种新型全局可见注意力机制下的改进版模型。
    • 不足:
      • 没有利用矩阵乘法单元,而现代加速器(如 GPU 和 TPU)正是为此而专门设计的。

Mamba2

Transformer is RNN
  • 文章:
    • Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention,2020
    • COSFORMER : RETHINKING SOFTMAX IN ATTENTION,2022
  • 介绍:引入线性注意力机制(Linear Transformer),将transformer的复杂度从O(N2)降低为O(N)。
  • softmax自注意力机制: Y = s o f t m a x ( Q K ⊤ ) ⋅ V Y=softmax(QK^⊤)⋅V Y=softmax(QK)V,其中 Q , K , V ∈ R ( 1 , P ) Q, K, V \in \mathbb{R}^{(1, P)} Q,K,VR(1,P),注意力机制需要一次次计算两两token之间的注意力,导致了二次方的计算复杂度。
    • 外层是对于每一个Query,我们需要计算它对应token的新表征。
    • 内层for循环是为了计算每一个Query对应的新表征,需要让该Query与每一个Key进行计算。
    • 大白话就是,军训时,甲乙丙丁4个人列成一队,计算注意力机制的过程相当于每个人都需要站出去算和自己在内所有人的相似度。
  • 问题:降低二次方的复杂度。
  • 方法:将softmax折叠到核特征映射中,并利用矩阵乘法的结合性将注意力计算中的矩阵左乘改成右乘。即 ( Q K ⊤ ) ⋅ V = Q ⋅ ( K ⊤ V ) \left(Q K^{\top}\right) \cdot V=Q \cdot\left(K^{\top} V\right) (QK)V=Q(KV)在这里插入图片描述
  • 推导:
    • 1.将自注意力的计算可以分解为向量运算: O = softmax ⁡ ( Q K T ′ d ) V = ∑ i = 1 T e q t ⊤ k i ⊙ v i ∑ i = 1 T e q t ⊤ k i O =\operatorname{softmax}\left(\frac{Q K^{T^{\prime}}}{\sqrt{d}}\right) V=\frac{\sum_{i=1}^{T} e^{q_{t}^{\top} k_{i}} \odot v_{i}}{\sum_{i=1}^{T} e^{q_{t}^{\top} k_{i}}} O=softmax(d QKT)V=i=1Teqtkii=1Teqtkivi在这里插入图片描述
    • 2.假设 s i m ( ) sim() sim()为抽象出的计算Query和Key相似度的函数,用下标i来表示矩阵的第 i 行(如 Q i Q_i Qi表示矩阵 Q 的第 i 行) O i = ∑ j = 1 N sim ⁡ ( Q i , K j ) ∑ j = 1 N sim ⁡ ( Q i , K j ) V j O_{i} =\frac{\sum_{j=1}^{N} \operatorname{sim}\left(Q_{i}, K_{j}\right) }{\sum_{j=1}^{N} \operatorname{sim}\left(Q_{i}, K_{j}\right)} V_{j} Oi=j=1Nsim(Qi,Kj)j=1Nsim(Qi,Kj)Vj
    • 3.Linear Transformer采用了kernel来定义 s i m ( ) sim() sim(),其中 ϕ \phi ϕ 是一个特征映射函数: sim ⁡ ( Q i , K j ) = ϕ ( Q i ) ϕ ( K j ) T \operatorname{sim}\left(Q_{i}, K_{j}\right)=\phi\left(Q_{i}\right) \phi\left(K_{j}\right)^{T} sim(Qi,Kj)=ϕ(Qi)ϕ(Kj)T,则 O i = ∑ j = 1 N ( ϕ ( Q i ) ϕ ( K j ) T ) V j ∑ j = 1 N ( ϕ ( Q i ) ϕ ( K j ) T ) O_{i}=\frac{\sum_{j=1}^{N}\left(\phi\left(Q_{i}\right) \phi\left(K_{j}\right)^{T}\right) V_{j}}{\sum_{j=1}^{N}\left(\phi\left(Q_{i}\right) \phi\left(K_{j}\right)^{T}\right)} Oi=j=1N(ϕ(Qi)ϕ(Kj)T)j=1N(ϕ(Qi)ϕ(Kj)T)Vj
    • 4.根据矩阵乘法结合律:【softmax只能左乘,linear可以右乘,而右乘更快】 ( ϕ ( Q ) ϕ ( K ) T ) V = ϕ ( Q ) ( ϕ ( K ) T V ) \left(\phi(Q) \phi(K)^{T}\right) V=\phi(Q)\left(\phi(K)^{T} V\right) (ϕ(Q)ϕ(K)T)V=ϕ(Q)(ϕ(K)TV),最终复杂度转化为: O ( N d 2 ) O\left(N d^{2}\right) O(Nd2)
    • 5.在一般的NLP任务中,一个头d的特征维度总是比输入序列长度 N ( d ≪ N ) N (d \ll N) N(dN)小得多,因此可以忽略d,实现O(N)的计算复杂度
    • 6.所以, O i = ϕ ( Q i ) ∑ j = 1 N ϕ ( K j ) T V j ϕ ( Q i ) ∑ j − 1 N ϕ ( K j ) T O_{i}=\frac{\phi\left(Q_{i}\right) \sum_{j=1}^{N} \phi\left(K_{j}\right)^{T} V_{j}}{\phi\left(Q_{i}\right) \sum_{j-1}^{N} \phi\left(K_{j}\right)^{T}} Oi=ϕ(Qi)j1Nϕ(Kj)Tϕ(Qi)j=1Nϕ(Kj)TVj
    • 7.我们可以再假设 S i = ∑ j = 1 i ϕ ( K j ) T V j = ϕ ( K i ) T V i + ∑ j = 1 i − 1 ϕ ( K j ) T V j = ϕ ( K i ) T V i + S i − 1 Z i = ∑ j = 1 i ϕ ( K j ) T = ϕ ( K i ) T + ∑ j = 1 i − 1 ϕ ( K j ) T = ϕ ( K i ) T + Z i − 1 \begin{array}{l} S_{i}=\sum_{j=1}^{i} \phi\left(K_{j}\right)^{T} V_{j}=\phi\left(K_{i}\right)^{T} V_{i}+\sum_{j=1}^{i-1} \phi\left(K_{j}\right)^{T} V_{j}=\phi\left(K_{i}\right)^{T} V_{i}+S_{i-1} \\ Z_{i}=\sum_{j=1}^{i} \phi\left(K_{j}\right)^{T}=\phi\left(K_{i}\right)^{T}+\sum_{j=1}^{i-1} \phi\left(K_{j}\right)^{T}=\phi\left(K_{i}\right)^{T}+Z_{i-1} \end{array} Si=j=1iϕ(Kj)TVj=ϕ(Ki)TVi+j=1i1ϕ(Kj)TVj=ϕ(Ki)TVi+Si1Zi=j=1iϕ(Kj)T=ϕ(Ki)T+j=1i1ϕ(Kj)T=ϕ(Ki)T+Zi1,当需要计算第 i 时刻的输出时,我们可以复用之前的状态 S i − 1 S_{i−1} Si1 Z i − 1 Z_{i−1} Zi1 ,再额外加上一个与当前时刻相关的计算量即可。而Transformer在计算第 i 时刻的输出时,它在第i-1个时刻的所有计算都无法被i时刻所复用。
  • 结论:Transformer is RNN
SSM is RNN

在这里插入图片描述

  • 左边就是一个简单的线性时不变系统建模
  • 中间是离散化后的模型就是个 RNN
  • 最右边是并行化用卷积核进行处理,也就是 CNN 化的模型。

这种表示方式是用图模型来建模,强调的是序列数据之间的依赖关系和动态变化。所谓的 SSM 其实可以理解为就是 RNN,只不过更强调通过线性代数方程来描述系统状态的变化,利用状态空间模型中的状态转移矩阵和观测知阵来进行建模。

对偶
  • 介绍:在数学、物理学乃至哲学中,“对偶性”是指两种看似不同的理论或模型之间存在的一种深层次的等价关系。通过这种对偶关系,可以将一个复杂的问题转化为另一个相对简单的问题来解决,或者在一种表示形式下无法轻易看到的性质在另一种表示形式下变得显而易见。
  • 思考:那么如何在注意力机制和SSM之间建立统一的对偶关联关系? => 结构化矩阵
结构化矩阵
  • 介绍:如果一个m×n阶矩阵只需要少于m×n个参数来描述,就是一个结构化矩阵(Structured Matrices)。如稀疏矩阵、低秩矩阵、Toeplitz矩阵、Cauchy矩阵、Vandermonde矩阵和蝶形矩阵。
    • Toeplitz 矩阵:每条对角线上的元素都相同的矩阵。在这里插入图片描述
    • Cauchy 矩阵:每个元素都由两个向量的元素之间的差的倒数来定义的矩阵。在这里插入图片描述
    • Vandermonde 矩阵:由一个向量的幂组成的矩阵。在这里插入图片描述
    • 低秩矩阵:其秩远小于其行或列数的矩阵。
  • 特性:压缩表示、通过快速算法直接操作这种压缩表示。
  • 目的:通过压缩表示可以用更少参数和更快算法计算,减少存储需求,加快运算速度。
  • SSM本质上也是一种结构化矩阵
    • SSM可以表示为 y=Mx 的形式,其中M是ABC的表达式。在这里插入图片描述
      • t = 0 时, h 0 = B 0 x 0 = ∑ s = 0 0 A 0 : s B s x s h_0=B_0x_0=∑_{s=0}^{0}A_{0:s}B_sx_s h0=B0x0=s=00A0:sBsxs A 0 : 0 A_{0:0} A0:0是单位矩阵 𝐼 】
      • t = 1 时, h 1 = A 1 h 0 + B 1 x 1 = A 1 B 0 x 0 + B 1 x 1 = ∑ s = 0 1 A 0 : s B s x s h_1=A_1h_0+B_1x_1=A_1B_0x_0+B_1x_1=∑_{s=0}^{1}A_{0:s}B_sx_s h1=A1h0+B1x1=A1B0x0+B1x1=s=01A0:sBsxs
      • t = 2 时, h 2 = A 2 h 1 + B 2 x 2 = A 2 ( A 1 B 0 x 0 + B 1 x 1 ) + B 2 x 2 = A 2 A 1 B 0 x 0 + A 2 B 1 x 1 + B 2 x 2 = ∑ s = 0 2 A 2 : s B s x s h_2=A_2h_1+B_2x_2=A_2(A_1B_0x_0+B_1x_1)+B_2x_2=A_2A_1B_0x_0+A_2B_1x_1+B_2x_2=∑_{s=0}^{2}A_{2:s}B_sx_s h2=A2h1+B2x2=A2(A1B0x0+B1x1)+B2x2=A2A1B0x0+A2B1x1+B2x2=s=02A2:sBsxs
      • 推广: h t = A t … A 1 B 0 x 0 + A t … A 2 B 1 x 1 + ⋯ + A t A t − 1 B t − 2 x t − 2 + A t B t − 1 x t − 1 + B t x t = ∑ t s = 0 A t : s B s x s h_t=A_t…A_1B_0x_0+A_t…A_2B_1x_1+⋯+A_tA_{t−1}B_{t−2}x_{t−2}+A_tB_{t−1}x_{t−1}+B_tx_t=\underset{s=0}{\overset{t}{∑}}A_{t:s}B_sx_s ht=AtA1B0x0+AtA2B1x1++AtAt1Bt2xt2+AtBt1xt1+Btxt=s=0tAt:sBsxs
      • y t = ∑ t s = 0 C t T A t : s B s x s y_t=\underset{s=0}{\overset{t}{∑}}C^{T}_{t}A_{t:s}B_sx_s yt=s=0tCtTAt:sBsxs
      • y = S S M ( A , B , C ) ( x ) = M x y=SSM(A,B,C)(x)=Mx y=SSM(A,B,C)(x)=Mx
      • SSM的矩阵变换形式: M j i : = C j T A j ⋯ A i + 1 B i M_{ji}:=C^{T}_{j}A_j⋯A_{i+1}B_i Mji:=CjTAjAi+1Bi
    • M 具有专门设计的半可分结构,能简化运算。
      • 顺序半可分矩阵 (SSS,SequentiallySemiseparable Structure):在这里插入图片描述
        • “半”:主要关注下三角部分。
        • “可分”:每个蓝色小块的秩较小,不超过N,意味着可以用更少的独立成分表示,从而实现高效计算。
        • 特点:序列化、下三角、低秩。
      • y = S S S ( A , B , C ) ⋅ x y = SSS(A, B, C)\cdot x y=SSS(A,B,C)x,状态空间模型SSM,如果状态的维度为N,等价于一个秩为N的SSS。也就是说任何SSM其实都可以转写成一个等价的局部下对角阵M的形式
  • 所以我们在SSM的计算中,特别是矩阵A,引入了类似注意力机制的公式和方法:
    在这里插入图片描述
    • 1.A的结构从对角线进一步简化为标量乘以单位矩阵结构。 在这种情况下,每个 A t A_t At也可以仅用一个标量来表示
    • 2.类似Transformer 中多头注意力的概念,使用了更大的头维度 P,相比于Mamba1中使用的 P = 1,通常选择P={64,128}
    • 3.使用类似注意力的对偶形式去除了 softmax,并引入了一个额外的掩码矩阵 L,根据数据生成,控制信息在时间上的传递量。在这里插入图片描述
      其中圆圈表示元素相乘,也就是哈德马积 a i a_i ai是依赖于输入的标量,范围在 [0, 1]之间。我们可以假设 a = [ a 1 , a 2 , a 3 ] a=[a_1,a_2,a_3] a=[a1,a2,a3],则在这里插入图片描述
      行标是i,列标是j,i<j的部分全是零,意味着只考虑时间上早于或同一时间点的元素之间的关系。换句话说,它是一种类似 GPT 模型中的单向注意力机制,只考虑过去的时间步,而不考虑未来的时间步。通过这种下三角矩阵,可以有效地控制信息在时间上的流动,确保信息只能从过去传递到现在,而不能反向传播。
张量收缩计算的对偶统一
  • SSM计算:SSS的计算过程可以被看作是一系列张量收缩操作,借助顺序半可分矩阵的特殊结构能实现高效计算。所以SSM理论上也可以这样做在这里插入图片描述
    第一步将输入矩阵X 与矩阵 B 进行结合,以产生一个中间结果 Z。矩阵A没有出现,它体现在第二步因状态更新中,L的定义依赖于 A。第三步是最终输出。在这里插入图片描述
  • 注意力机制计算:
    • 注意力基本形式:在这里插入图片描述
      • S 和 T 表示源和目标序列长度,分别意指:source、target之意
      • N表示特征维度
      • P表示头维度
    • 自注意力:
      • 源序列和目标序列相同(即 S = T)
      • 通常特征维度和头维度相同(即 N = P)
      • 𝑄, 𝐾, 𝑉是通过对同一输入向量的线性投影生成的,即 Q = W Q ⋅ X , K = W K ⋅ X , V = W V ⋅ X Q=W_Q⋅X,K=W_K⋅X,V=W_V⋅X Q=WQX,K=WKX,V=WVX
    • 掩码(核)注意力 :
      • y = ( L ∘ ( Q K ⊤ ) ) ⋅ V y=(L∘(QK^⊤))⋅V y=(L(QK))V
      • 将其分解为精确的计算序列:在这里插入图片描述先计算相似性矩阵 G = Q K T G=QK^T G=QKT,应用掩码矩阵 L 后得到新的相似性矩阵 M = G ⊙ L M=G\odot L M=GL,然后再计算最终的输出 Y = M V Y=MV Y=MV,后面对应的是相应的维度。
      • 写成张量收缩的形式: Y = c o n t r a c t ( T N , S N , S P , T S → T P ) ( Q , K , V , L ) Y = contract(TN, SN, SP, TS → TP) (Q, K, V , L) Y=contract(TN,SN,SP,TSTP)(Q,K,V,L)
      • 进一步拆解成多步收缩:在这里插入图片描述
    • 线性注意力
      • 用张量收缩表达: Y = ( L ∘ ( Q K ⊤ ) ) ⋅ V = Q ⋅ c u m s u m ( K T V ) Y=(L∘(QK^⊤))⋅V=Q⋅cumsum(K^TV) Y=(L(QK))V=Qcumsum(KTV)
      • 多步收缩变为了:在这里插入图片描述
    • 结构化掩码注意力SMA:
      • 无论是从状态空间模型(SSM)侧,还是从注意力机制侧来看,都可以统一到张量收缩的视角下进行操作。
      • 不同的掩码矩阵(如因果掩码、衰减掩码、Toeplitz 矩阵等)L 定义了不同的序列变换矩阵 M,从而实现不同形式的结构化注意力。在这里插入图片描述

于是,Mamba-2孕育而生。

Mamba2
  • 论文名称:Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality,202405
  • 论文地址:https://arxiv.org/abs/2405.21060
  • 论文代码:https://github.com/state-spaces/mamba
  • 论文作者:Tri Dao, Albert Gu
  • 研究设计和结论
    • 问题: 既然SSM和注意力机制两种对偶等价,那就研究如何结合起来进行计算。在这里插入图片描述
    • 方法:整体上是SSM,但是通过块分解把大矩阵拆解成小的子矩阵,每个小问题特别是低秩块,再用注意力机制计算,利用矩乘法上的高效性和并行计算能力使得计算过程更加高效。【分解块,独立运算,并行计算】
      • 如下图所示,一个大的 M矩阵,分解成9 块,其中蓝色块用矩阵乘法。在这里插入图片描述
      • 半可分矩阵 MMM 被分解成多个矩阵块:
        • 对角块(Diagonal Block)表示输入到输出的计算,
        • 低秩块 (Low-Rank Block)分成三类:
          • 从输入到状态 (lnput State)。
          • 从状态到状态 (State - State)。
          • 从状态到输出 (State ->Output)。在这里插入图片描述
      • 这种块分解方法进行计算的流程。在这里插入图片描述
        • 输入序列 X被分解成多个块,每个块对应图中的一个黑色虚线框。
        • 输入块通过低秩块 (绿色头)和对角块(橙色头)进行计算,得到中间的状态块 H。
        • 状态块之间通过低秩块(黄色箭头)进行计算,表示状态间的传递。
        • 最终,状态块通过对角块 (蓝色箭头)计算得到输出块 Y。
      • 核心代码在这里插入图片描述
        • segsum(x):用于计算分段累加和。
        • ssd(X,A,B,C,block len=64,initial_states=None):用于计算SSD 模型,其中A,B,C分别表示状态矩阵、扩展矩阵和收缩矩阵
          • 0.先对输入张量 X、A、B和C进行重排,将它们重排成块的形式
          • 1.对每个块内的对角块 (Diagonal Block)进行计算,使用 torch.einsum 计算块内的矩阵乘法。
          • 2.计算每个块内的低秩块 (Lw-Rank Block),用于生成下一个块的输入状态。
          • 3.生成块间的状态转移,确保在块边界处的状态正确。
          • 4.对块内的低秩块进行计算,将状态转换为输出。
          • 5.最后将块内和块间的输出汇总,得到最终输出Y和最终状态 final state。
        • 总结:通过块分解的方法,将一个大规模的状态空间模型问题分解成多个小规模的块级别运算问题。这种方法利用了半可分矩阵的特性,能够提高计算效率和并行性,适合硬件加速。
      • Mamba-2 架构:在这里插入图片描述
        • Mamba-1基于SSM (状态空间模型)设计,线性投影之后生成SSM 参数 A,B,C。
        • Mamba-2的两种块设计:
          • 顺序 Mamba 块。SSM 层从 A,X,B,C 直接映射到输出Y,在序列变换中并行生成参数 A,B,C。
          • 并行 Mamba块。SSM 层从 A,X,B,C 直接映射到输出Y,在块开始时并行生成,适合更大规模的并行处理。这种方法类似于标准注意力架构中的并行生成 Q,K,V。这种设计能够减少参数数量,适合更大模型的张量并行计算。
          • 每个 Mamba 块中增加额外归一化层,改善模型稳定性。
          • Mamba-2 模型通过并行化和增加归一化层来优化原始 Mamba 模型的计算效率和稳定性。
        • 并行化处理方法
          • 张量并行(Tensor Parallelism,简称TP)在这里插入图片描述
            • 左侧:输入和输出投影矩阵分割,并在单个设备上处理。每个SSM头(即A、B、C、X)到Y 的映射都在单个设备上进行。最终归一化层选择 GroupNorm,以避免额外的通信。
            • 右侧:将序列维度上的计算分配到多个设备上,每个设备负责一部分序列的计算,然后将结果传递给下一个设备。
          • 序列/上下文并行(Sequence/Context Parallelism):对于非常长的序列,可能需要沿序列长度维度将输入和激活拆分到不同的GPU上
            • 用于残差和归一化操作的序列并行(SP)。
            • 序列并行用于token混合操作(注意力或SSM),也称为“上下文并行”(context parallelism,简称CP)。
    • 实验结果:
      • 合成记忆任务:该图展示了不同模型在多查询关联记忆(MQAR)任务中的表现在这里插入图片描述
        • 三张子图对应不同的序列长度(256、512、1024)、横轴代表模型的维度 (32、64、128、256)、纵轴代表准确率(从0到1)。
        • Mamba-2 系列模型在较大的模型维度下表现优异,特别是当维度达到128和256时其准确率接近1.0。
        • Mamba-2 模型明显优于 Mamba-1 和普通注意力模型,尤其在更大的状态规模(N=256)下表现尤为显著。
      • 语言模型预训练和评估在这里插入图片描述
        • 在The Pile 上进行训练的模型,Mamba-2的性能匹配或超过了Mamba 和强大的“Transformer++”方案。
        • 与Transformer 基线相比,Mamba-2 在性能(困感度)、理论FLOPs和实际壁钟时间上都是帕累托占优的。
      • 零样本评估
        • 在每种模型规模中,Mamba-2 模型的表现普遍优于其他模型。特别是Mamba-2在较大的模型规模下(27B 参数)表现尤为突出,证明其在不同任务上的泛化能力更强。在这里插入图片描述
        • 在不同数量的注意力层下的困感度。大约10%的注意力层比例表现最佳。适量的注意力层可以显著提高模型性能,超过了完全不使用注意力层或完全使用注意力层的情况。在这里插入图片描述
        • 比较了SSD、MLP和注意力层的不同组合方式,在2.7B 规模上进行评估困感度(ppl)和准确率 (acc)。Mamba-2 与注意力层的结合 (后4个)在多个任务上的表现优于其他模型组合,显示出更强的泛化能力和任务适应性。在这里插入图片描述
      • 速度性能在这里插入图片描述
        • 左图:不同方法在处理序列长度(从512到512k)时所需的时间 (以毫秒为单位)。
        • 右图:展示了在处理固定序列长度(4K)时,不同状态维度(从16到56下所需的时间(以毫秒为单位)。
        • SSD 方法在处理大状态扩展时表现优异,比Mamba 的融合扫描快2到8倍(比如64k时紫色线1毫秒,Mamba为10 毫秒),并且在序列长度超过2k 时也比 FlashAttention-2更快。
      • 消融实验在这里插入图片描述
        • Mamba-2 模块在结合并行处理和额外归一化后,显著提升了模型性能,表现优于传统的Mamba-1 模块。
          在这里插入图片描述
        • 多头结构中,复杂的头组合和状态扩展通常可以提高性能,特别是当模型规模增大时对于核近似,Swish和LayerNorm 方法通常效果较好,且适用于不同规模的模型
        • 增加复杂度和头的数量一般有助于提高模型性能,但需要权衡参数数量的增加。
  • 论文贡献
    • 创新点:将 SSM 模型与注意力机制结合,实现了模型并行堆叠,提升了模型的堆叠能力。大白话就是一个流淌的大水管子,通过设计不同的阀门来调节记忆的融合和特征的捕捉。【水暖工旋钮思想】

xLSTM

LSTM
  • 大白话理解:因记忆能力有限,记住重要的,忘记无关紧要的。在这里插入图片描述
  • RNN —> LSTM:
    在这里插入图片描述
  • tanh 激活函数的作用:帮助调节流经网络的值,使得数值始终限制在 -1 和 1 之间。
    在这里插入图片描述
  • Sigmoid 激活函数的作用:帮助调节流经网络的值,使得数值始终限制在 0 和 1 之间。在这里插入图片描述
  • 公式:在这里插入图片描述
  • Forget Gate(遗忘门):
    • 在这里插入图片描述
  • Input Gate(输入门):
    • 在这里插入图片描述
  • Cell State(细胞状态):
    • 在这里插入图片描述
  • OutPut Gate(输出门):
    • 在这里插入图片描述
  • 贡献:一定程度上解决梯度消失和梯度爆炸的问题
  • 问题:
    • 在处理长序列时效率低;
    • 记忆容量有限;
    • 不能并行处理数据;
sLSTM
  • 公式:
    在这里插入图片描述
    • 输入门和遗忘门的激活函数从 sigmoid 改成了指数函数(红色部分)
      • 指数函数相比于 sigmoid 函数,具有更大的输出范围和更大的梯度(右图黄色,左图红色),可以减轻梯度消失问题,使得梯度在反向传播过程中不会迅速减小,从而使得模型在训练时能够更有效地更新权重。
      • 指数函数的增长速度比 sigmoid 函数快,对输入变化更加敏感。因此,可以更迅速地强烈的调整输入和遗忘门的输出,使得模型能够更快地捕捉到输入信息的变化,更加选择性地记住或忘记信息,从而提高模型的记忆和遗忘能力。
      • 这种强烈的选择性,让模型能够更准确地保留重要信息和丢弃不重要的信息。在特定任务(如长序列的最近邻搜索或稀有事件预测) 中表现得尤为显著,能够显著提升模型性能。
    • 引入了归一化状态 n t n_t nt(公式9),相当于搞了一个大分母,因为指数激活函数可能导致数值过大而溢出。
    • 相应的隐层 h t h_t ht的计算方式变了,改成了 c t / n t c_t/n_t ct/nt;也就是公式 (10)
    • 引入了一个额外状态 m_t 来进一步稳定门控在这里插入图片描述
      • 式子(15)使用了 log,指数函数的逆运算,相当于降一级运算,然后取最大值,意思就是输入门和遗忘门都别太猛。
      • 根据 m t m_t mt 再调整输入门和遗忘门,相当于设置了一个缓冲区。【忍一时风平浪静,退一步海阔天空】
  • 贡献:解决了敏感度,某种程度上也是长序列处理效率问题。
mLSTM
  • 公式:在这里插入图片描述
    • 状态和权重参数都变成了矩阵形式,对应的运算变成了向量矩阵乘法和哈达玛积,公式(21)。
    • 增加了 q t , k t , v t q_t, k_t,v_t qt,kt,vt这种键值对的计算公式(22-24),优化了子注意力机制,多了好几个权重模型增强了模型表达能力。
    • 这种框架可以使用多头模式,头与头之间没有记忆混合,因此可以充分并行,无形中提升了并行能力。
  • 贡献:增强LSTM的存储能力。
xLSTM

线性不可分

  • 设输入空间为X={(0,0),(1,1),(0,1),(1,0)},显然输入空间是二维的。对于异或问题,(0,0),(1,1)的值为0(图中用o表示),(0,1),(1,0)的值为1(图中用×表示),显然没有一条直线可以解决异或问题的分类在这里插入图片描述
  • 现在我们将输入空间中的任一点 ( x 1 , x 2 ) (x_1,x_2) (x1,x2),转换为 ( x 1 , x 2 , ( x 1 − x 2 ) 2 ) (x_1,x_2,(x_1-x_2)^2) (x1,x2,(x1x2)2),称转换后的空间为特征空间,记为Z,则Z={(0,0,0),(1,1,0),(0,1,1),(1,0,1)}。显然特征空间是三维的。(0,0,0),(1,1,0)的值为0(图中用o表示),均在底平面;(0,1,1),(1,0,1)的值为1(图中用×表示),均在顶平面。显然,任意一个在底平面和顶平面之间的平面可以解决异或问题的分类在这里插入图片描述

核方法 (Kernel Method)

  • 动机:解决分类任务中线性不可分的问题。
  • 方法:将数据映射到更高维的空间,将原本在低维空间中线性不可分的数据转换为在高维空间中线性可分的数据,以大大降低分类任务的难度。
  • 问题:更高维的数据就更加难以训练,也更容易过拟合
  • 应用:Transformer 模型就是通过多头注意力机制在高维空间中进行并行处理,使得不同位置的特征可以相互影响和结合,从而提高了模型的性能。

那么xLSTM如何升维?

  • “先干后变”:先在原始空间中总结信息,然后映射到高维空间,再返回原始空间。看图从下往上输入sLSTM,然后向上投影,也就是用一个倒着的梯形矩阵升维,处理后再降维。在这里插入图片描述

    • 内部结构:在这里插入图片描述
      • 输入先LN(LayerNorm)整理,然后一分为二。一部分卷积提取特征,激活非线性变换(左),另一部分直接输入sLSTM(右)
        • LN(LayerNorm):层归一化,帮助稳定和加速训练过程。
        • Conv4:卷积层,卷积核大小为 4。提取局部特征。
        • Swish:一种平滑的非线性激活函数,可以帮助模型学习到更复杂的模式。
      • 接着所有运算都采用了4个头的多头并行进制,每个头可以专注于捕捉输入数据的不同特征或模式,从而使模型能够更全面地理解数据。
        • 内部采用块对角线结构,在计算时可以并行处理,从而显著降低计算复杂度和内存需求;
        • 每个子矩阵(块)主要关注输入数据的一部分能够更好地捕捉局部特征;结构化的稀疏性,这有助于减少过拟合。
      • 在 sLSTM 图中的箭头表示信息在不同时间步之间的流动和处理,代表的是与先前时刻状态的混合计算。这部分相当于记忆的重新组合
      • 然后组内归一化、降维、再激活、再降维,然后与残差相加再输出。
        • GN(GroupNorm):组归一化(Group Normalization)。在每一组内进行归一化,有助于加速训练和提高模型稳定性,特别是在小批量(batch)训练时。
        • PF=3/4 和 PF=4/3:投影因子(Projection Factor),分别将输入维度缩小为原来的 3/4,将输入维度扩大 4/3 倍。
        • GeLU(Gaussian Error Linear Unit,高斯误差线性单元)也是一种激活函数,数学表达式为: G E L U ( x ) = x Φ ( x ) GELU(x)=xΦ(x) GELU(x)=xΦ(x),其中Φ(x)是输入x的标准正态累积分布函数(CDF),具体公式为:在这里插入图片描述
          图像为:
          在这里插入图片描述
          特性:
          • 平滑性:与ReLU的尖锐转折点不同,GELU提供了一个平滑的激活曲线,这有助于深度学习模型在训练过程中更加稳定,特别是在处理不连续输入数据时。
          • 非饱和性:GELU与ReLU一样,具有非饱和性质,这意味着它可以缓解梯度消失问题,特别是在训练深层网络时。
          • 自适应门控机制:GELU通过其内部的高斯CDF组件,实现了一种自适应的门控机制。这意味着GELU可以根据输入的属性自动调整激活的量,类似于神经元的开/关切换,这有助于网络自动学习重要的特征。
  • “先变后干” :先映射到高维空间,总结信息后再返回原始空间。也就是输入直接上投影,再用mLSTM处理,然后再降维。在这里插入图片描述

    • 高维空间中的记忆容量更大,因此用有矩阵化记忆单元的mLSTM更合适,而在低维空间处理sLSTM更合适。
    • 内部结构:在这里插入图片描述
      • 整体上都是充分利用了残差堆叠结构,层归一化技术等稳定网络,通过升降维度实现空间变换,激活函数非线性变换,然后利用 LSTM 进行记忆混合,或者说时序上的选择性自注意力机制计算,采用多头和块对角模式实现并行处理,当然也没少了用卷积提取特征。
      • PF=1/2 和 PF=2:投影因子(Projection Factor)。前者将输入维度缩小一半,后者将输入维度扩大两倍。
      • LSkip 是个跳线,类似于残差连接,可以帮助梯度更好地传递,防止梯度消失和爆炸。这里相当于有两种跳线残差。
      • BS=4:块大小为 4 的块对角投影矩阵。
      • mLSTM 单元中的q、k、v分别表示查询(query)、键(key)和值(value),我们刚讲过,都是从输入中生成的,用于计算注意力权重和进行信息检索。

实验论证:

  • 合成任务和长程任务在这里插入图片描述
    • 每行表示一种模型,包括Llama、Mamba等7种模型的12中变体,xLSTM[0:1]:主要是sLSTM块;xLSTM[1:0]:主要是mLSTM块;xLSTM[1:1]:均衡使用mLSTM和sLSTM块。
    • 每列表示一种任务,包括上下文敏感、确定性上下文无关、正则,最后是多数任务,也是正则。
    • 使用sLSTM和mLSTM的组合(如xLSTM[1:1])在大多数任务上表现出色,特别是在
      复杂和状态跟踪任务上。
    • 不同模型在多查询联想记忆任务中的性能对比:
      在这里插入图片描述
      • 图表说明:横轴模型的尺寸,纵轴验证准确率。
      • 实验结果:xLSTM[1:1](粉色)表现最佳,越难越好,Llama等Transformer模型在较小和中等难度任务中表现优越。Mamba略强。
  • 验证集困惑度比较 在这里插入图片描述
    • 数据集:15B个Token训练的 SlimPajama。
    • 比较内容:下一词预测性能。
    • 图表说明:横轴为模型参数量,纵轴为验证困惑度。
    • 实验结果:总体趋势都差不多,但xLSTM明显更好。说明其在语言建模任务中的优势。
  • 大规模语言建模实验在这里插入图片描述
    • 数据集:300B个Token训练的SlimPajama
    • 比较内容:不同模型在下一词预测任务中的验证困惑度(Validation Perplexity),特别是对长序列的外推性能。
    • 图表说明:横轴为token数量,也就是序列长度。纵轴为验证困惑度。
    • 实验结果:
      • 大多数模型在序列长度增加时验证困惑度上升。
      • xLSTM模型在较长序列上保持了较低的验证困惑度,显示了其在处理长序列时的优势。
  • 语言基准测试在这里插入图片描述
    • 数据集:300B个Token训练的SlimPajama
    • 比较内容:不同模型在下一词预测任务中的验证困惑度(Validation Perplexity)随参数数量变化的情况。
    • 图表说明:验证困惑度越低,表示模型的预测性能越好。
    • 实验结果:
      • 所有模型的验证困惑度随着参数数量的增加而下降,说明更大参数的模型在下一词预测任务上表现更好。
      • xLSTM模型(特别是xLSTM[7:1]和xLSTM[1:0])在所有参数数量下都表现出色,验证困惑度较低,说明其在语言建模任务中的性能优越。
      • xLSTM模型比Mamba表现好,而Mamba比Llama表现好。这表明xLSTM在处理大规模语言建模任务时,具有明显的优势。

代码实现:https://github.com/AI-Guru/xlstm-resources?tab=readme-ov-file

创新点:
在这里插入图片描述
借助指数门控混合记忆和新内存结构,LSTM增强为sLSTM和mLSTM。二者的结合构成了xLSTM模块,进一步堆叠可以实现大模型化。

TTT

  • 论文名称:Learning to (Learn at Test Time): RNNs with Expressive Hidden States,202407
  • 论文地址:https://arxiv.org/abs/2407.04620
  • 论文代码:https://github.com/test-time-training/ttt-lm-pytorch
  • 论文作者:Yu Sun, Xinhao Li, et al.
  • 关键词和摘要总结
    • 关键词:Test-Time Training (TTT), RNNs, Sequence Modeling
    • 摘要总结:提出TTT层,该层在保持线性复杂度的同时,通过将隐藏状态视为机器学习模型,并使用自监督学习更新规则,以提高长上下文建模的表达力。
  • 研究设计和结论
    • 问题: 虽然RNN有着线性的复杂度,但RNN在面对长下文时处理起来会比较困难。【RNN层必须将上下文压缩成固定大小的隐藏状态,其作为一种压缩启发式方法,更新规则需要发现成千上万甚至数百万个token之间的潜在结构和关系】
    • 方法
      • TTT层:在这里插入图片描述

        • 将固定大小的隐状态变量 ( s t s_t st) 更换为小的 MLP 网络,提升了表达能力和自适应性。
        • 传统的线性组合加激活函数更新,变为梯度下降法,实现边测试边学习。
        • 输出规则: z t = f ( x t ; W t ) z_{t}=f\left(x_{t} ; W_{t}\right) zt=f(xt;Wt)
          • 大白话理解:输出token z t z_t zt基于由 f 使用更新后的权重 W t W_t Wt进行预测 x t x_t xt
        • 更新规则: W t = W t − 1 − η ∇ ℓ ( W t − 1 ; x t ) W_{t}=W_{t-1}-\eta \nabla \ell\left(W_{t-1} ; x_{t}\right) Wt=Wt1η(Wt1;xt)
          • 引入并行化:因为 W t W_t Wt在两个地方依赖于 W t − 1 W_{t-1} Wt1无法并行化,一个是在减号前,一个是在 ∇ l \nabla l l内部。
          • 问题核心: ∇ l \nabla l l内部包含了大部分计算,需要重点对 ∇ l \nabla l l做并行化
          • 梯度下降GD:在这里插入图片描述
          • 批量梯度下降:对所有这些变量相对于 W 0 W_0 W0进行计算, G t = ∇ ℓ ( W 0 ; x t ) G_{t}=\nabla \ell\left(W_{0} ; x_{t}\right) Gt=(W0;xt),但 W t W_t Wt实际上只比 W 0 W_0 W0多一步梯度步长,导致批量梯度下降的有效搜索空间会相对比较小。
          • 在线梯度下降:在线梯度下降, W t W_t Wt距离 W 0 W_0 W0有t步之遥。
          • 小批量梯度下降:在批量梯度下降与在线梯度下降折中,即将批量大小设置为相对较小的b。 G t = ∇ ℓ ( W t ′ ; x t ) G_{t}=\nabla \ell\left(W_{t^{\prime}} ; x_{t}\right) Gt=(Wt;xt),其中 t ′ = t −   m o d   ( t , b ) t^{\prime}=t-\bmod (t, b) t=tmod(t,b)在这里插入图片描述
      • 训练:内外环嵌套训练

        • 外环:RNN 大网络,负责整体模型参数优化。包含一个 Task 对象,定义了模型参数和损失函数 MSE。 forward方法接收输入序列并遍历每个标记,调用 Learner 对象进行训练和预测。在整个网络的训练过程中, Task 类的参数( θ K \theta_K θK θ V \theta_V θV θ Q \theta_Q θQ )会被优化,以提高模型的整体性能。借鉴了 QKV 模式,但本质上并不是自注意力机制。在这里插入图片描述
        • 内环:MLP 小网络,负责具体训练和预测。在每个时间步内,Learner 对当前输入进行训练,计算损失函数的梯度并更新模型参数。 Learner 使用更新后的模型对当前输入进行预测,并返回预测结果。在这里插入图片描述
      • 模型变体:

        • 参数化模型:尝试了简单的线性模型和多层 MLP(Y轴),优化器方面尝试了使用所有数据的批量梯度下降和小批量梯度下降法(X轴)。 在这里插入图片描述TTT-linear 在使用批量梯度下降(Batch GD)时, TTT-Linear 的输出等价于线性注意力。
          • 对于TTT-Linear, f l i n ( x ) = W x f_{lin}(x)=Wx flin(x)=Wx,其中W是方阵。
          • 对于TTT-MLP, f M L P f_{MLP} fMLP有两层,类似于Transformer中的MLP。
          • 为了在TTT期间获得更好的稳定性,f总是包含层归一化(LN)和残差连接。即, f ( x ) = x + LN ⁡ ( f r e s ( x ) ) f(x)=x+\operatorname{LN}\left(f_{\mathrm{res}}(x)\right) f(x)=x+LN(fres(x)),其中 f r e s f_{res} fres可以是 f l i n f_{lin} flin f M L P f_{MLP} fMLP
        • 非参数化模型:如 Nadaraya-Watson 估计器,理论上等价于自注意力机制。【唬人的,连固定的模型和参数都不要,就是一种
          非线性加权求和运算,根据训练数据直接进行预测。】
    • 实验结果:TTT-Linear 和 TTT-MLP 在长上下文处理和推理延迟方面表现优于 Transformer 和 Mamba 模型。
      • Pile 数据集:左图 2k 上下文长度上,蓝色的 TTT-Linear(M)和红色 Mamba 表现相当,橙色的 TTT-MLP(M)略差。右图 8k 上下文长度:蓝色 TTT-Linear(M)和橙色 TTT-MLP(M)表现优于红色 Mamba。在这里插入图片描述
      • Books 数据集:
        • 左图2k 上下文长度:蓝色 TTT-Linear(M)和橙色 TTT-MLP(M)表现优于红色Mamba 和紫色 Transformer。 右图32k 上下文长度: TTT-Linear(M)和 TTT-MLP(M)在长上下文。在这里插入图片描述
        • 1k 到 32k 上下文长度所有上下文长度: TTT-Linear 和 TTT-MLP 表现稳定,长上下文下优于 Transformer 和 Mamba。在这里插入图片描述
      • TTT-MLP 的表现,橙色短上下文略差,越往右长上下文表现更好下表现更好在这里插入图片描述
      • NVIDIA A100 GPU 上的延迟:在这里插入图片描述
        • 前向延迟:蓝色 TTT-Linear 和红色 Mamba 延迟相同,绿色 Transformer 延迟较高。
        • 生成延迟:蓝色 TTT-Linear 和红色 Mamba 延迟低于绿色 Transformer 和橙色 TTT-MLP。
  • 论文贡献
    • 创新点:
      • 引入了边测试边学习的框架,提升了自适应能力。
      • 结合内外环模型的思路,实现了模型参数的宏观和微观优化。
      • 隐状态的神经网络化,增强了模型的表达能力和动态调整能力。
    • 不足:
      • 新的学习模式带来了新的问题,如初始化、学习时长和边学边干中的冲突和矛盾需要进一步研究
      • 非参数化 TTT 模型尚未进行实际的实验验证。

📜 参考资料 :

https://zhuanlan.zhihu.com/p/716838305
https://space.bilibili.com/1921388479
https://blog.csdn.net/v_JULY_v/article/details/134923301
https://blog.csdn.net/v_JULY_v/article/details/132178447
📌 [ 笔者 ]   文艺倾年
📃 [ 更新 ]   2024.8.27
❌ [ 勘误 ]   /* 暂无 */
📜 [ 声明 ]   由于作者水平有限,本文有错误和不准确之处在所难免,
              本人也很想知道这些错误,恳望读者批评指正!

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值