Ollama 支持的 flash attention 能提升推理速度吗?我们一起测测看吧

这是 Ollama 支持的 flash attention 能提升推理速度吗?我们一起测测看吧 的笔记哦,查看更详尽的内容,请观看视频,谢谢。

ollama 最近的更新还是蛮频繁的。继上次更新了并发请求之后,最新的版本 0.1.39 则是支持了 flash attention 。

flash attention 可以显著减少注意力机制计算的时间,使得Transformer模型在训练和推理时能够更快地处理大量数据,而且它通过优化内存访问模式,减少了计算过程中的内存占用。这不仅有助于降低内存开销,还减少了计算过程中的内存带宽瓶颈。

这里可以理解为在 attention 计算比较密集的情况下,flash attention 能够显著减少计算量,在推理时能够更快地处理大量数据。那么对应的,上下文越长,flash attention 所带来的优势就越明显。

这里我就对比下开关 flash attention 对推理时间的影响。

环境准备

这里还是使用一块 4090 云 GPU 进行测试。我还是用 llama3 模型。通过命令获取。

ollama pull llama3

然后通过下面的命令可以启动 ollama:

OLLAMA_HOST=0.0.0.0:6006 OLLAMA_NUM_PARALLEL=16 OLLAMA_MODELS=/root/autodl-tmp/models ollama serve

其中:

  1. OLLAMA_HOST=0.0.0.0:6006:指定 Ollama 服务监听的主机和端口,0.0.0.0 表示监听所有网络接口上的连接,端口为 6006
  2. OLLAMA_NUM_PARALLEL=16:设置 Ollama 服务的并行处理数量为 16,表示可以同时处理 16 个请求。
  3. OLLAMA_MODELS=/root/autodl-tmp/models:指定 Ollama 模型的存储路径为 /root/autodl-tmp/models

默认 ollama 并没有打开 flash attention 需要我们在启动的时候,增加环境变量 OLLAMA_FLASH_ATTENTION=1 来启动:

OLLAMA_FLASH_ATTENTION=1 \
OLLAMA_NUM_PARALLEL=16 \
OLLAMA_HOST=0.0.0.0:6006 \
OLLAMA_MODELS=/root/autodl-tmp/models \
ollama serve

修改 open webui 的配置,直接使用云端的 ollama

首先,我在本地通过 docker 已经运行了一个 open webui。然后点击左下角 User -> Settings -> Connections ,修改 Ollama Base URL 为云端 ollama 的地址。

在这里插入图片描述
然后准备一个特别长的提示词:

下面是一系列的新闻:


<此处略过 1000 字>

请你帮我进行分析和处理,做如下事情:

  1. 帮我做总结
  2. 帮我依据以上内容准备 5 个问题

这里我先插入了一段从 csdn 摘录的新闻,然后让 llama3 帮我进行总结,并准备 5 个问题。这也是目前非常流行的 RAG 的场景:先提供一个上下文,然后基于上下文所提供的信息提问。

然后就可以通过 OLLAMA_FLASH_ATTENTION=1 的有无去对比下实际效果了。

脚本准备

下一步我还是使用之前的脚本对 flash attention 进行测试,但是这次我需要一个更长的上下文。

import aiohttp
import asyncio
import time
from tqdm import tqdm

question = """
下面是一系列的新闻:

---

<此处省略 1000 字,如果你想要使用这个脚本,务必把这里替换为一个很长的新闻稿>

---

请你帮我进行分析和处理,做如下事情:

1. 帮我做总结
2. 帮我依据以上内容准备 5 个问题
"""

async def fetch(session, url):
    """
    参数:
        session (aiohttp.ClientSession): 用于请求的会话。
        url (str): 要发送请求的 URL。
    
    返回:
        tuple: 包含完成 token 数量和请求时间。
    """
    start_time = time.time()

    # 请求的内容
    json_payload = {
        "model": "llama3",
        "messages": [{"role": "user", "content": question}],
        "stream": False,
        "temperature": 0.7 # 参数使用 0.7 保证每次的结果略有区别
    }
    async with session.post(url, json=json_payload) as response:
        response_json = await response.json()
        end_time = time.time()
        request_time = end_time - start_time
        completion_tokens = response_json['usage']['completion_tokens'] # 从返回的参数里获取生成的 token 的数量
        return completion_tokens, request_time

async def bound_fetch(sem, session, url, pbar):
    # 使用信号量 sem 来限制并发请求的数量,确保不会超过最大并发请求数
    async with sem:
        result = await fetch(session, url)
        pbar.update(1)
        return result

async def run(load_url, max_concurrent_requests, total_requests):
    """
    通过发送多个并发请求来运行基准测试。
    
    参数:
        load_url (str): 要发送请求的URL。
        max_concurrent_requests (int): 最大并发请求数。
        total_requests (int): 要发送的总请求数。
    
    返回:
        tuple: 包含完成 token 总数列表和响应时间列表。
    """
    # 创建 Semaphore 来限制并发请求的数量
    sem = asyncio.Semaphore(max_concurrent_requests)
    
    # 创建一个异步的HTTP会话
    async with aiohttp.ClientSession() as session:
        tasks = []
        
        # 创建一个进度条来可视化请求的进度
        with tqdm(total=total_requests) as pbar:
            # 循环创建任务,直到达到总请求数
            for _ in range(total_requests):
                # 为每个请求创建一个任务,确保它遵守信号量的限制
                task = asyncio.ensure_future(bound_fetch(sem, session, load_url, pbar))
                tasks.append(task)  # 将任务添加到任务列表中
            
            # 等待所有任务完成并收集它们的结果
            results = await asyncio.gather(*tasks)
        
        # 计算所有结果中的完成token总数
        completion_tokens = sum(result[0] for result in results)
        
        # 从所有结果中提取响应时间
        response_times = [result[1] for result in results]
        
        # 返回完成token的总数和响应时间的列表
        return completion_tokens, response_times

if __name__ == '__main__':
    import sys

    if len(sys.argv) != 3:
        print("Usage: python bench.py <C> <N>")
        sys.exit(1)

    C = int(sys.argv[1])  # 最大并发数
    N = int(sys.argv[2])  # 请求总数

    # vllm 和 ollama 都兼容了 openai 的 api 让测试变得更简单了
    url = 'http://localhost:11434/v1/chat/completions'

    start_time = time.time()
    completion_tokens, response_times = asyncio.run(run(url, C, N))
    end_time = time.time()

    # 计算总时间
    total_time = end_time - start_time
    # 计算每个请求的平均时间
    avg_time_per_request = sum(response_times) / len(response_times)
    # 计算每秒生成的 token 数量
    tokens_per_second = completion_tokens / total_time

    print(f'Performance Results:')
    print(f'  Total requests            : {N}')
    print(f'  Max concurrent requests   : {C}')
    print(f'  Total time                : {total_time:.2f} seconds')
    print(f'  Average time per request  : {avg_time_per_request:.2f} seconds')
    print(f'  Tokens per second         : {tokens_per_second:.2f}')

通过命令 python bench.py 16 48 分别测试开启 flash attention 和关闭 flash attention 的结果即可。

  • 23
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值