一、定义
- 定义
- paged attention_v2 使用
- vllm 中 flash attention 代码
- xformer中flash attention 代码
- paged attention 、flash attention 比较
- nsight compute 分析paged attention
- profiler 分析paged attention
二、实现 - 定义
- paged attention_v2 使用
vLLM 中的 paged attention 由 CUDA kernel 实现,Python 端通过 `vllm._custom_ops` 调用,如下图。
python 调用
使用代码:
import random
from typing import List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
from vllm.utils import (create_kv_caches_with_random)
# Size of one float32 element in bytes (32 bits / 8).
FLOAT32_BYTES = torch.finfo(torch.float32).bits // 8
# Upper bound for generated context lengths: the device's max shared memory
# expressed in float32 elements, minus 512 elements kept as a buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
# Number of physical KV-cache blocks that random block tables index into.
NUM_BLOCKS = 4321
# Partition length (in tokens) used to size the v2 scratch buffers.
PARTITION_SIZE = 512
def test_paged_attention(
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> None:
    """Run one vLLM paged-attention kernel (v1 or v2) on random inputs and print the output.

    Args:
        version: "v1" or "v2"; selects ``ops.paged_attention_v1`` / ``_v2``.
        num_seqs: number of sequences in the batch (one query token each).
        num_heads: ``(num_query_heads, num_kv_heads)``; the first must be a
            multiple of the second.
        head_size: per-head hidden size.
        use_alibi: if True, random ALiBi slopes are passed to the kernel.
        block_size: tokens per KV-cache block; must divide PARTITION_SIZE for v2.
        dtype: dtype of the query and KV caches.
        kv_cache_dtype: cache dtype string forwarded to the kernel.
        seed: seed for Python/torch RNGs and for the random KV caches.
        device: torch device; also installed as the global default device.

    Raises:
        AssertionError: on an unknown ``version`` or incompatible sizes.
    """
    # Make every random draw below reproducible.
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)
    assert num_query_heads % num_kv_heads == 0

    alibi_slopes = (
        torch.randn(num_query_heads, dtype=torch.float) if use_alibi else None
    )

    # Random context lengths, with the last sequence pinned to the maximum.
    lengths = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    lengths[-1] = MAX_SEQ_LEN
    max_seq_len = max(lengths)
    seq_lens = torch.tensor(lengths, dtype=torch.int)

    # Every sequence gets a full-length table of random physical block ids
    # (ids may repeat across and within tables).
    max_num_blocks_per_seq = -(-max_seq_len // block_size)  # ceil division
    block_tables_lst: List[List[int]] = [
        [random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)]
        for _ in range(num_seqs)
    ]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

    # Random key/value caches for a single layer.
    key_caches, value_caches = create_kv_caches_with_random(
        NUM_BLOCKS, block_size, 1, num_kv_heads, head_size,
        kv_cache_dtype, dtype, seed, device,
    )
    key_cache = key_caches[0]
    value_cache = value_caches[0]

    kv_scale = 1.0  # default kv_scale
    output = torch.empty_like(query)

    if version == "v1":
        ops.paged_attention_v1(
            output, query, key_cache, value_cache, num_kv_heads, scale,
            block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
            kv_cache_dtype, kv_scale,
        )
    elif version == "v2":
        num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
        assert PARTITION_SIZE % block_size == 0
        n_seqs, n_heads, h_size = output.shape
        # Scratch buffers for v2, one slot per partition.
        tmp_output = torch.empty(
            size=(n_seqs, n_heads, num_partitions, h_size), dtype=output.dtype
        )
        exp_sums = torch.empty(
            size=(n_seqs, n_heads, num_partitions), dtype=torch.float32
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output, exp_sums, max_logits, tmp_output, query, key_cache,
            value_cache, num_kv_heads, scale, block_tables, seq_lens,
            block_size, max_seq_len, alibi_slopes, kv_cache_dtype, kv_scale,
        )
    else:
        raise AssertionError(f"Unknown version: {version}")
    print(output)
# Example run: v1 kernel, 7 sequences, 40 query / 40 KV heads of size 64, bf16.
test_paged_attention(
    "v1",
    num_seqs=7,
    num_heads=(40, 40),
    head_size=64,
    use_alibi=False,
    block_size=16,
    dtype=torch.bfloat16,
    kv_cache_dtype="auto",
    seed=0,
    device="cuda:0",
)
- vllm 中flash attention 代码
import torch
from typing import List,Optional
from vllm_flash_attn import flash_attn_with_kvcache
# Paged-KV decode with vllm_flash_attn: each sequence contributes a single
# query token that attends over its KV cache, addressed through block_tables.
head_size=128
num_heads=(16, 16)
kv_lens=[1328, 18, 463]
# NOTE: 16 blocks * 16 tokens is far fewer than the 83 blocks the longest
# sequence needs, so the random block tables reuse physical blocks. This is
# fine for a smoke test on random data, but it is not a realistic cache layout.
NUM_BLOCKS=16
block_size=16
dtype=torch.float16
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
query = torch.randn((num_seqs, num_query_heads, head_size), dtype=dtype)
# Paged cache layout: (NUM_BLOCKS, block_size, num_kv_heads, head_size).
key_cache = torch.randn((NUM_BLOCKS, block_size, num_kv_heads, head_size), dtype=dtype)
value_cache = torch.randn_like(key_cache)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32)
output = flash_attn_with_kvcache(
    q=query.unsqueeze(1),  # [3, 1, 16, 128]: (num_seqs, seqlen_q=1, heads, head_size)
    k_cache=key_cache,  # [16, 16, 16, 128]: (NUM_BLOCKS, block_size, num_kv_heads, head_size)
    v_cache=value_cache,  # same shape as k_cache
    softmax_scale=scale,  # 128 ** -0.5 = 0.08838834764831845
    causal=True,
    block_table=block_tables,  # [3, 83]: ceil(1328 / 16) = 83 block ids per sequence
    cache_seqlens=kv_lens_tensor,  # [1328, 18, 463]
).squeeze(1)  # drop the seqlen_q dim -> [3, 16, 128]
print(output)
- xformer中flash attention 代码
import torch
import random
from typing import List, Optional, Tuple
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm.utils import get_max_shared_memory_bytes, is_hip
FLOAT32_BYTES = torch.finfo(torch.float32).bits // 8
# Varies with the GPU's compute capability (through its shared-memory limit);
# 512 elements are subtracted as a safety buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> torch.Tensor:
    """Run xformers memory-efficient attention over packed variable-length sequences.

    ``num_seqs`` distinct random lengths drawn from ``range(1, min(MAX_SEQ_LEN, 4096))``
    are packed along the token dimension, and a ``BlockDiagonalCausalMask``
    built from those lengths is applied.

    Args:
        num_seqs: number of sequences to pack.
        num_heads: ``(num_query_heads, num_kv_heads)``; KV heads are repeated
            when there are more query heads (MQA/GQA).
        head_size: per-head hidden size.
        dtype: dtype of the packed q/k/v buffer.
        seed: seed for Python/torch RNGs.
        device: torch device; also installed as the global default device.

    Returns:
        The attention output from ``memory_efficient_attention_forward``
        (inputs are passed as ``(1, num_tokens, heads, head_size)``).
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    # One packed buffer holds query, key and value heads side by side.
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA: replicate each KV head for its query group.
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    # Return the result so callers such as `res = ...` receive it.
    return output
# Example: 3 packed sequences, 40 query / 40 KV heads of size 64, fp16.
res = test_multi_query_kv_attention(
    3,
    (40, 40),
    64,
    dtype=torch.float16,
    seed=0,
    device="cuda:0",
)
- paged attention 、flash attention 比较
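用时比较(用时越短越快):paged attention v2 < paged attention(v1)< flash attention
内存占用:paged attention v2 > paged attention(v1)> flash attention
注意:CUDA kernel 是异步启动的,用 time.time() 计时时必须在计时前后调用 torch.cuda.synchronize(),否则测到的只是 kernel 启动开销,上述用时结论应在同步计时下复核。另外,下面代码中的 flash attention(xformers)对整段打包序列做 prefill,而 paged attention 对每个序列只计算 1 个 query token(decode),两者的计算量并不相同。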
import random
from typing import List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
from vllm.utils import (create_kv_caches_with_random)
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from xformers import ops as xops
FLOAT32_BYTES = torch.finfo(torch.float32).bits // 8
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
print(MAX_SEQ_LEN)  # the author observed 41216 on their GPU
NUM_BLOCKS = 4321
PARTITION_SIZE = 512

# Seed the Python and torch RNGs once for the whole script; the timed
# test_paged_attention below does not reseed.
seed = 0
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
torch.set_default_device("cuda:0")
import time
def test_paged_attention(
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> torch.Tensor:
    """Time a single vLLM paged-attention kernel (v1 or v2) on random inputs.

    RNGs are seeded once at module level, not here, so repeated calls draw
    new sequence lengths and block tables; ``seed`` is only forwarded to
    ``create_kv_caches_with_random``.

    Args:
        version: "v1" or "v2"; selects ``ops.paged_attention_v1`` / ``_v2``.
        num_seqs: number of sequences in the batch (one query token each).
        num_heads: ``(num_query_heads, num_kv_heads)``; the first must be a
            multiple of the second.
        head_size: per-head hidden size.
        use_alibi: if True, random ALiBi slopes are passed to the kernel.
        block_size: tokens per KV-cache block; must divide PARTITION_SIZE for v2.
        dtype: dtype of the query and KV caches.
        kv_cache_dtype: cache dtype string forwarded to the kernel.
        seed: seed for the random KV caches.
        device: device the caches are created on and that is synchronized
            around the timed kernel.

    Returns:
        The attention output, shaped like the query
        ``(num_seqs, num_query_heads, head_size)``.

    Raises:
        AssertionError: on an unknown ``version`` or incompatible sizes.
    """
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    assert num_query_heads % num_kv_heads == 0
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)
    # Honour use_alibi (it used to be ignored: slopes were always None).
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)
    # Create the block tables (random physical block ids per sequence).
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = [
        [random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)]
        for _ in range(num_seqs)
    ]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
    # Create the KV caches (a single layer).
    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, block_size, 1,
                                                            num_kv_heads, head_size,
                                                            kv_cache_dtype, dtype, seed,
                                                            device)
    key_cache, value_cache = key_caches[0], value_caches[0]
    # Using default kv_scale
    kv_scale = 1.0
    output = torch.empty_like(query)
    if version == "v1":
        # Kernel launches are asynchronous: drain pending GPU work before
        # starting the clock and wait for the kernel before stopping it,
        # otherwise only the launch overhead is measured.
        torch.cuda.synchronize(device)
        start = time.perf_counter()
        ops.paged_attention_v1(
            output,
            query,  # e.g. [7, 40, 64] for the example call
            key_cache,  # e.g. [4321, 40, 8, 16, 8]
            value_cache,  # e.g. [4321, 40, 64, 16]
            num_kv_heads,  # e.g. 40
            scale,  # e.g. 0.125
            block_tables,  # e.g. [7, 256]
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            kv_scale,
        )
        torch.cuda.synchronize(device)
        end = time.perf_counter()
        print(f"paged attention:{end-start}")
    elif version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        # Scratch buffers for v2, one slot per partition.
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        # Synchronize on both sides of the launch (see the v1 branch).
        torch.cuda.synchronize(device)
        start = time.perf_counter()
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            kv_scale,
        )
        torch.cuda.synchronize(device)
        end = time.perf_counter()
        print(f"paged attention v2:{end - start}")
    else:
        raise AssertionError(f"Unknown version: {version}")
    return output
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm.utils import get_max_shared_memory_bytes, is_hip
FLOAT32_BYTES = torch.finfo(torch.float32).bits // 8
# Varies with the GPU's compute capability (through its shared-memory limit);
# 512 elements are subtracted as a safety buffer.
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
@torch.inference_mode()
def test_multi_query_kv_attention(
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> torch.Tensor:
    """Time xformers memory-efficient attention over packed variable-length sequences.

    ``num_seqs`` distinct random lengths drawn from ``range(1, min(MAX_SEQ_LEN, 4096))``
    are packed along the token dimension, and a ``BlockDiagonalCausalMask``
    built from those lengths is applied.

    Args:
        num_seqs: number of sequences to pack.
        num_heads: ``(num_query_heads, num_kv_heads)``; KV heads are repeated
            when there are more query heads (MQA/GQA).
        head_size: per-head hidden size.
        dtype: dtype of the packed q/k/v buffer.
        seed: seed for Python/torch RNGs.
        device: torch device; installed as the default device and
            synchronized around the timed call.

    Returns:
        The attention output from ``memory_efficient_attention_forward``.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    max_len = min(MAX_SEQ_LEN, 4096)
    seq_lens = random.sample(range(1, max_len), num_seqs)
    num_tokens = sum(seq_lens)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    # One packed buffer holds query, key and value heads side by side.
    qkv = torch.empty(num_tokens,
                      num_query_heads + 2 * num_kv_heads,
                      head_size,
                      dtype=dtype)
    qkv.uniform_(-scale, scale)
    query, key, value = qkv.split(
        [num_query_heads, num_kv_heads, num_kv_heads], dim=1)
    num_queries_per_kv = num_query_heads // num_kv_heads
    if num_queries_per_kv > 1:
        # Handle MQA and GQA: replicate each KV head for its query group.
        key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
        value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
    attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
    # CUDA work is asynchronous: drain pending work before starting the clock
    # and wait for the attention kernel before stopping it, otherwise only the
    # launch overhead is measured.
    torch.cuda.synchronize(device)
    start = time.perf_counter()
    output = xops.memory_efficient_attention_forward(
        query.unsqueeze(0),
        key.unsqueeze(0),
        value.unsqueeze(0),
        attn_bias=attn_bias,
        p=0.0,
        scale=scale,
    )
    torch.cuda.synchronize(device)
    end = time.perf_counter()
    print(f"flash attention:{end - start}")
    return output
torch.cuda.empty_cache()

# Shared settings for every run below. num_seqs is the number of sequences
# in the batch, not a sequence length.
_mqa_kwargs = dict(
    num_seqs=7,
    num_heads=(40, 40),
    head_size=64,
    dtype=torch.float16,
    seed=0,
    device="cuda:0",
)
_paged_kwargs = dict(
    num_seqs=7,
    num_heads=(40, 40),
    head_size=64,
    use_alibi=False,
    block_size=16,
    dtype=torch.bfloat16,
    kv_cache_dtype="auto",
    seed=0,
    device="cuda:0",
)

# Each kernel is run more than once, presumably so that later runs exclude
# one-time warm-up cost.
res = test_multi_query_kv_attention(**_mqa_kwargs)
output1 = test_paged_attention(version="v1", **_paged_kwargs)
torch.cuda.empty_cache()
output2 = test_paged_attention(version="v1", **_paged_kwargs)
torch.cuda.empty_cache()
output2 = test_paged_attention(version="v2", **_paged_kwargs)
torch.cuda.empty_cache()
output2 = test_paged_attention(version="v2", **_paged_kwargs)
torch.cuda.empty_cache()
res = test_multi_query_kv_attention(**_mqa_kwargs)
torch.cuda.empty_cache()
res = test_multi_query_kv_attention(**_mqa_kwargs)
- nsight compute 分析paged attention
test.py 文件:
import random
from typing import List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
from vllm.utils import (create_kv_caches_with_random)
FLOAT32_BYTES = torch.finfo(torch.float32).bits // 8
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
print(MAX_SEQ_LEN)  # the author observed 41216 on their GPU
NUM_BLOCKS = 4321
PARTITION_SIZE = 512

# Seed the Python and torch RNGs once for the whole script; the
# test_paged_attention below does not reseed.
seed = 0
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
torch.set_default_device("cuda:0")
import time
def test_paged_attention(
    version: str,
    num_seqs: int,
    num_heads: Tuple[int, int],
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    seed: int,
    device: str,
) -> torch.Tensor:
    """Time one vLLM paged-attention kernel (v1 or v2) and report CUDA memory use.

    RNGs are seeded once at module level, not here; ``seed`` is only
    forwarded to ``create_kv_caches_with_random``. After the kernel, the
    peak and current allocator usage on ``device`` are printed in MiB.

    Args:
        version: "v1" or "v2"; selects ``ops.paged_attention_v1`` / ``_v2``.
        num_seqs: number of sequences in the batch (one query token each).
        num_heads: ``(num_query_heads, num_kv_heads)``; the first must be a
            multiple of the second.
        head_size: per-head hidden size.
        use_alibi: if True, random ALiBi slopes are passed to the kernel.
        block_size: tokens per KV-cache block; must divide PARTITION_SIZE for v2.
        dtype: dtype of the query and KV caches.
        kv_cache_dtype: cache dtype string forwarded to the kernel.
        seed: seed for the random KV caches.
        device: CUDA device used for the caches, timing sync and memory stats.

    Returns:
        The attention output, shaped like the query
        ``(num_seqs, num_query_heads, head_size)``.

    Raises:
        AssertionError: on an unknown ``version`` or incompatible sizes.
    """
    # Reset the peak counter so the reported peak reflects this call only.
    torch.cuda.reset_peak_memory_stats(device)
    scale = float(1.0 / (head_size**0.5))
    num_query_heads, num_kv_heads = num_heads
    assert num_query_heads % num_kv_heads == 0
    query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
    query.uniform_(-scale, scale)
    # Honour use_alibi (it used to be ignored: slopes were always None).
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
    seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int)
    # Create the block tables (random physical block ids per sequence).
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables_lst: List[List[int]] = [
        [random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)]
        for _ in range(num_seqs)
    ]
    block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
    # Create the KV caches (a single layer).
    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, block_size, 1,
                                                            num_kv_heads, head_size,
                                                            kv_cache_dtype, dtype, seed,
                                                            device)
    key_cache, value_cache = key_caches[0], value_caches[0]
    # Using default kv_scale
    kv_scale = 1.0
    output = torch.empty_like(query)
    if version == "v1":
        # Kernel launches are asynchronous: drain pending GPU work before
        # starting the clock and wait for the kernel before stopping it,
        # otherwise only the launch overhead is measured.
        torch.cuda.synchronize(device)
        start = time.perf_counter()
        ops.paged_attention_v1(
            output,
            query,  # e.g. [7, 40, 64] for the example call
            key_cache,  # e.g. [4321, 40, 8, 16, 8]
            value_cache,  # e.g. [4321, 40, 64, 16]
            num_kv_heads,  # e.g. 40
            scale,  # e.g. 0.125
            block_tables,  # e.g. [7, 256]
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            kv_scale,
        )
        torch.cuda.synchronize(device)
        end = time.perf_counter()
        print(f"paged attention:{end-start}")
    elif version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        assert PARTITION_SIZE % block_size == 0
        num_seqs, num_heads, head_size = output.shape
        # Scratch buffers for v2, one slot per partition.
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, num_partitions, head_size),
            dtype=output.dtype,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, num_partitions),
            dtype=torch.float32,
        )
        max_logits = torch.empty_like(exp_sums)
        # Synchronize on both sides of the launch (see the v1 branch).
        torch.cuda.synchronize(device)
        start = time.perf_counter()
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            scale,
            block_tables,
            seq_lens,
            block_size,
            max_seq_len,
            alibi_slopes,
            kv_cache_dtype,
            kv_scale,
        )
        torch.cuda.synchronize(device)
        end = time.perf_counter()
        print(f"paged attention v2:{end - start}")
    else:
        raise AssertionError(f"Unknown version: {version}")
    # Peak and current caching-allocator usage in MiB on the same device.
    print(torch.cuda.max_memory_allocated(device) / 1024 ** 2)  # author observed ~192M
    print("=============================")
    print(torch.cuda.memory_allocated(device) / 1024 ** 2)  # author observed ~128M
    return output
# The run profiled by ncu: a single v1 call.
output1 = test_paged_attention(
    "v1",
    num_seqs=7,  # number of sequences in the batch (not a sequence length)
    num_heads=(40, 40),
    head_size=64,
    use_alibi=False,
    block_size=16,
    dtype=torch.bfloat16,
    kv_cache_dtype="auto",
    seed=0,
    device="cuda:0",
)
torch.cuda.empty_cache()
测试:>> /home/vllm-main/tests/kernels# ncu --set full -o paged_atten python test.py
在 Nsight Compute 客户端(GUI)中打开上面命令生成的 paged_atten.ncu-rep 报告进行分析:
1. 屋顶线(Roofline)模型分析
分析结论:
1. paged attention v1 计算强度4.66<内存强度15.59,属于内存访问密集任务。
2. paged attention v1 的带宽L1、L2、DRAM 分别为6.43、17.03、15.59。
3. paged attention v1 位于带宽瓶颈区,模型的计算强度<平台的计算强度。模型的吞吐量<平台的吞吐量。
4. paged attention v2 计算强度26.04<75.58, 仍然属于访问密集任务。
5. paged attention v2 的带宽L1、L2、DRAM 分别为20.91、76.81、75.58。
6. v2 比 v1 的计算强度与吞吐量均有提高,执行时间缩短(执行速度提高)。
paged attention v1 优化方向:提高计算强度、提高模型的吞吐速度。
paged attention v2 优化方向:仍可提高。
待续
-
带宽分析
L1/TEX Hit Rate:L1 命中率
L2 Hit Rate:L2 命中率
L2 Compression Success Rate: L2 压缩成功率
Mem Busy:内存利用率
Mem Pipes Busy: 内存管道忙碌状态
L2 Compression Ratio: L2压缩率
memory chart:内存图表直观地显示了数据访问的流向以及各级缓存的命中率。从左向右看,最左侧是执行计算的内核(kernel)。- 优化方向1:从 L1TEX 到 L2 的存储(store)访问模式不是最优的。L1TEX 对 L2 的请求粒度是 128 字节的缓存行,即每个 L2 请求包含 4 个连续的 32 字节扇区;然而这个内核平均只访问了每个缓存行 4 个扇区中的 1.5 个。请检查 Source Counters 部分中是否存在未合并(uncoalesced)的存储,并尽量减少每个内存请求需要访问的缓存行数。
- 优化方向2:从设备内存加载数据的访问模式导致从 DRAM 读取了 55330588 个扇区,是 L2 缓存未命中扇区数(55318297)的 1.0 倍。L2 读取未命中时,DRAM 的读取粒度为 64 字节,即 L2 缓存行的下半部分或上半部分。应尝试调整访问模式,充分利用 DRAM 每次读请求返回的两个扇区,以提高 DRAM 吞吐量的利用率。对于跨步(strided)读取,应避免 64 字节或更大的步长,以免把用不到的扇区从 DRAM 搬到 L2。
-
内存分析
1. warp occupancy(通常简称 Occupancy)是衡量 GPU 利用率和并行度的关键性能指标。具体来说,Occupancy 是指一个流多处理器(Streaming Multiprocessor,简称 SM)上活跃 warp 的数量与该 SM 支持的最大活跃 warp 数量之比。这里的“活跃”指 warp 已驻留在 SM 上,包括正在等待数据的 warp,而不只是正在执行的 warp。warp 是 GPU 调度和执行线程的基本单位,在 NVIDIA GPU 上固定包含 32 个线程,这些线程并行执行相同的指令。
2 Occupancy直接反映了GPU的硬件利用率。高Occupancy通常意味着GPU的并行处理资源得到了有效利用,可以执行更多的计算任务。
3. 虽然高Occupancy不总是代表高性能(因为还可能受到其他因素的影响,如内存带宽、缓存效率等),但低Occupancy通常会降低性能,因为它限制了GPU隐藏延迟的能力,导致计算单元在等待数据时处于空闲状态。
4. Occupancy 是指一个 SM 中 active warps 数量和最大可能的 active warps 数量的比值。
影响Occupancy的因素
- 寄存器使用:每个线程使用的寄存器数量会影响可并发执行的warp数量。如果每个线程使用的寄存器过多,那么SM上能够同时执行的warp数量就会减少,从而降低Occupancy。
- 共享内存使用:每个block使用的共享内存也会影响Occupancy。共享内存是有限的资源,如果每个block使用的共享内存过多,那么能够同时执行的block数量就会减少,进而降低Occupancy。
- block 的配置:block 的大小(即每个 block 中的线程数)会影响 Occupancy。warp 的大小在 NVIDIA GPU 上固定为 32 个线程,不能配置,因此 block 大小最好取 32 的整数倍。通过合理配置这些参数,可以在不同场景下优化 Occupancy。
如何优化Occupancy
- 减少寄存器使用:通过优化代码,减少每个线程使用的寄存器数量,可以增加同时执行的warp数量,从而提高Occupancy。
- 合理配置 block:根据具体的应用场景和 GPU 的硬件特性,选择合适的 block 大小(warp 大小固定为 32),以达到最佳的 Occupancy。
- 优化内存访问:通过优化数据布局和访问模式,减少内存访问延迟和冲突,可以提高 GPU 的缓存效率和内存带宽利用率。这并不会直接提高 Occupancy(Occupancy 由寄存器、共享内存和 block 配置决定),但能减少 warp 等待数据的时间,与提高 Occupancy 一样有助于隐藏延迟。
- 上图分析:attention v1 与 attention v2 使用相同数量的寄存器和相同的 block 配置,但 v1 占用更多的共享内存,导致每个 SM 上能同时驻留的 block 数减少,从而使 Occupancy 降低。
- profiler 分析paged attention
# Profile one v1 paged-attention call with torch.profiler, recording input
# shapes, memory allocations and Python stacks. The trace is handed to
# tensorboard_trace_handler, which writes it under ./log/example.
trace_handler = torch.profiler.tensorboard_trace_handler('./log/example')
with torch.profiler.profile(
    on_trace_ready=trace_handler,
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
):
    output1 = test_paged_attention(
        "v1",
        num_seqs=7,  # number of sequences in the batch (not a sequence length)
        num_heads=(40, 40),
        head_size=64,
        use_alibi=False,
        block_size=16,
        dtype=torch.bfloat16,
        kv_cache_dtype="auto",
        seed=0,
        device="cuda:0",
    )
torch.cuda.empty_cache()
用 TensorBoard 加载 profiler 结果(需安装 torch-tb-profiler 插件;--logdir 应指向保存 trace 的目录,上面代码写入的是 ./log/example)>>(base) D:\cnki_1\model_predict>tensorboard --logdir=D:/log