PagedAttention Study Notes


I. Definition

  1. Definition
  2. Using paged_attention_v2
  3. Flash attention code in vLLM
  4. Flash attention code in xFormers
  5. Comparing paged attention and flash attention
  6. Analyzing paged attention with Nsight Compute
  7. Analyzing paged attention with the PyTorch profiler

II. Implementation

  1. Definition
  2. Using paged_attention_v2
    In vLLM, paged attention is implemented as a CUDA kernel, as shown below.
    (figure: CUDA implementation of the paged attention kernel)
    The Python side calls it through vllm._custom_ops. (figure: Python call site)
    Example usage:
import random
from typing import List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
from vllm.utils import (create_kv_caches_with_random)

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
NUM_BLOCKS = 4321  
PARTITION_SIZE = 512
def test_paged_attention(
        version: str,
        num_seqs: int,
        num_heads: Tuple[int, int],
        head_size: int,
        use_alibi: bool,
        block_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: str,
        seed: int,
        device: str,
) -> None:
        random.seed(seed)
        torch.random.manual_seed(seed)
        if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
        torch.set_default_device(device)
        scale = float(1.0 / (head_size**0.5))
        num_query_heads, num_kv_heads = num_heads
        query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
        query.uniform_(-scale, scale)

        assert num_query_heads % num_kv_heads == 0
        alibi_slopes = None
        if use_alibi:
                alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)

        seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
        seq_lens[-1] = MAX_SEQ_LEN
        max_seq_len = max(seq_lens)
        seq_lens = torch.tensor(seq_lens, dtype=torch.int)

        # Create the block tables.
        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
        block_tables_lst: List[List[int]] = []
        for _ in range(num_seqs):
                block_table = [
                        random.randint(0, NUM_BLOCKS - 1)
                        for _ in range(max_num_blocks_per_seq)
                ]
                block_tables_lst.append(block_table)

        block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

        # Create the KV caches.
        key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, block_size, 1,
                                                                num_kv_heads, head_size,
                                                                kv_cache_dtype, dtype, seed,
                                                                device)
        key_cache, value_cache = key_caches[0], value_caches[0]
        # Using default kv_scale
        kv_scale = 1.0

        # Call the paged attention kernel.
        output = torch.empty_like(query)
        if version == "v1":
                ops.paged_attention_v1(
                        output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        kv_scale,
                )
        elif version == "v2":
                num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
                assert PARTITION_SIZE % block_size == 0
                num_seqs, num_heads, head_size = output.shape
                tmp_output = torch.empty(
                        size=(num_seqs, num_heads, num_partitions, head_size),
                        dtype=output.dtype,
                )
                exp_sums = torch.empty(
                        size=(num_seqs, num_heads, num_partitions),
                        dtype=torch.float32,
                )
                max_logits = torch.empty_like(exp_sums)
                ops.paged_attention_v2(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        kv_scale,
                )
        else:
                raise AssertionError(f"Unknown version: {version}")
        print(output)

test_paged_attention(
        version="v1",
        num_seqs=7,
        num_heads=(40, 40),
        head_size=64,
        use_alibi=False,
        block_size=16,
        dtype=torch.bfloat16,
        kv_cache_dtype="auto",
        seed=0,
        device="cuda:0",)
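When version="v2" is used, the kernel writes a partial result for every 512-token partition into tmp_output, together with that partition's exp_sums and max_logits, and a second pass reduces them into output. The snippet below is an illustrative re-implementation of that reduction (a minimal sketch of the standard log-sum-exp merge, assuming the buffer layout created above and that every partition is in use; it is not vLLM's CUDA code):

def combine_partitions(tmp_output, exp_sums, max_logits):
        # tmp_output: [num_seqs, num_heads, num_partitions, head_size]
        # exp_sums, max_logits: [num_seqs, num_heads, num_partitions]
        global_max = max_logits.max(dim=-1, keepdim=True).values       # [S, H, 1]
        rescaled = exp_sums * torch.exp(max_logits - global_max)       # un-normalized weight of each partition
        weights = rescaled / rescaled.sum(dim=-1, keepdim=True)        # [S, H, P]
        merged = (weights.unsqueeze(-1) * tmp_output.float()).sum(dim=-2)
        return merged.to(tmp_output.dtype)                             # [S, H, head_size]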
  3. Flash attention code in vLLM
import torch
from typing import List,Optional
from vllm_flash_attn import flash_attn_with_kvcache

head_size=128
num_heads=(16, 16)
kv_lens=[1328, 18, 463]
NUM_BLOCKS=16
block_size=16
dtype=torch.float16
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)

scale = head_size**-0.5
query = torch.randn((num_seqs, num_query_heads, head_size), dtype=dtype)
key_cache = torch.randn((NUM_BLOCKS, block_size, num_kv_heads,  head_size),  dtype=dtype)
value_cache = torch.randn_like(key_cache)
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,  NUM_BLOCKS,   (num_seqs, max_num_blocks_per_seq),  dtype=torch.int32)

output = flash_attn_with_kvcache(
        q=query.unsqueeze(1),            # [3, 1, 16, 128]: one query token per sequence
        k_cache=key_cache,               # [16, 16, 16, 128] = [NUM_BLOCKS, block_size, num_kv_heads, head_size]
        v_cache=value_cache,             # same shape as key_cache
        softmax_scale=scale,             # 128 ** -0.5 ≈ 0.0884
        causal=True,
        block_table=block_tables,        # [3, 83]
        cache_seqlens=kv_lens_tensor,    # [1328, 18, 463]
    ).squeeze(1)

print(output)
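The block_table is what makes the KV cache "paged": row i lists, in logical order, the physical blocks of key_cache/value_cache that hold sequence i's tokens. A minimal sketch of how one sequence's keys could be gathered back into a contiguous tensor (gather_keys is a hypothetical helper for illustration, not part of vllm_flash_attn):

def gather_keys(key_cache, block_table, kv_len, block_size):
        # key_cache: [num_blocks, block_size, num_kv_heads, head_size]
        # block_table: physical block ids for one sequence, in logical order
        blocks = key_cache[block_table.long()]      # [max_blocks, block_size, num_kv_heads, head_size]
        tokens = blocks.flatten(0, 1)               # tokens in logical order
        return tokens[:kv_len]                      # only the first kv_len tokens are valid

keys_seq0 = gather_keys(key_cache, block_tables[0], kv_lens[0], block_size)
print(keys_seq0.shape)                              # torch.Size([1328, 16, 128])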
  4. Flash attention code in xFormers
import random
from typing import List, Optional, Tuple

import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from vllm.utils import get_max_shared_memory_bytes, is_hip

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# This will change depending on the compute capability.
# - 512 as a buffer
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512


@torch.inference_mode()
def test_multi_query_kv_attention(
        num_seqs: int,
        num_heads: Tuple[int, int],
        head_size: int,
        dtype: torch.dtype,
        seed: int,
        device: str,
) -> torch.Tensor:
        random.seed(seed)
        torch.random.manual_seed(seed)
        if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
        torch.set_default_device(device)

        max_len = min(MAX_SEQ_LEN, 4096)
        seq_lens = random.sample(range(1, max_len), num_seqs)
        num_tokens = sum(seq_lens)
        scale = float(1.0 / (head_size**0.5))
        num_query_heads, num_kv_heads = num_heads
        qkv = torch.empty(num_tokens,
                          num_query_heads + 2 * num_kv_heads,
                          head_size,
                          dtype=dtype)
        qkv.uniform_(-scale, scale)
        query, key, value = qkv.split(
                [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

        num_queries_per_kv = num_query_heads // num_kv_heads
        if num_queries_per_kv > 1:
                # Handle MQA and GQA
                key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
                value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
        output = xops.memory_efficient_attention_forward(
                query.unsqueeze(0),
                key.unsqueeze(0),
                value.unsqueeze(0),
                attn_bias=attn_bias,
                p=0.0,
                scale=scale,
        )
        return output

res = test_multi_query_kv_attention(
        num_seqs=3,
        num_heads=(40, 40),
        head_size=64,
        dtype=torch.float16,
        seed=0,
        device="cuda:0",
)
  5. Comparing paged attention and flash attention
    Elapsed time: paged attention v2 < paged attention v1 < flash attention
    Memory usage: paged attention v2 > paged attention v1 > flash attention
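The timing ordering comes from the benchmark script below. The memory ordering can be spot-checked per variant with PyTorch's peak-memory counter (a minimal sketch, not part of the script that follows; reset the counter before each variant so the peaks are comparable):

torch.cuda.reset_peak_memory_stats("cuda:0")
# ... run one attention variant here ...
print(torch.cuda.max_memory_allocated("cuda:0") / 1024 ** 2, "MiB")

The benchmark script: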
import random
from typing import List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
from vllm.utils import (create_kv_caches_with_random)
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
from xformers import ops as xops

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
print(MAX_SEQ_LEN)        #41216
NUM_BLOCKS = 4321
PARTITION_SIZE = 512
seed=0
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
torch.set_default_device("cuda:0")
import time
def test_paged_attention(
        version: str,
        num_seqs: int,
        num_heads: Tuple[int, int],
        head_size: int,
        use_alibi: bool,
        block_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: str,
        seed: int,
        device: str,
) -> None:

        scale = float(1.0 / (head_size**0.5))
        num_query_heads, num_kv_heads = num_heads
        query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
        query.uniform_(-scale, scale)

        alibi_slopes = None

        print(MAX_SEQ_LEN)
        seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
        seq_lens[-1] = MAX_SEQ_LEN
        max_seq_len = max(seq_lens)
        seq_lens = torch.tensor(seq_lens, dtype=torch.int)

        # Create the block tables.
        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
        block_tables_lst: List[List[int]] = []
        for _ in range(num_seqs):
                block_table = [
                        random.randint(0, NUM_BLOCKS - 1)
                        for _ in range(max_num_blocks_per_seq)
                ]
                block_tables_lst.append(block_table)

        block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

        # Create the KV caches.
        key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, block_size, 1,
                                                                num_kv_heads, head_size,
                                                                kv_cache_dtype, dtype, seed,
                                                                device)
        key_cache, value_cache = key_caches[0], value_caches[0]
        # Using default kv_scale
        kv_scale = 1.0

        output = torch.empty_like(query)
        if version == "v1":
                start=time.time()
                ops.paged_attention_v1(
                        output,
                        query,           #[7,40,64]
                        key_cache,       #[4321, 40, 8, 16, 8]
                        value_cache,     #[4321, 40, 64, 16]
                        num_kv_heads,    #40
                        scale,           #0.125
                        block_tables,    # [7, 2576] with MAX_SEQ_LEN = 41216
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,     #None
                        kv_cache_dtype,
                        kv_scale,
                )
                torch.cuda.synchronize()  # kernel launches are asynchronous; wait before timing
                end = time.time()
                print(f"paged attention v1: {end - start}")
        elif version == "v2":
                num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
                assert PARTITION_SIZE % block_size == 0
                num_seqs, num_heads, head_size = output.shape
                tmp_output = torch.empty(
                        size=(num_seqs, num_heads, num_partitions, head_size),
                        dtype=output.dtype,
                )
                exp_sums = torch.empty(
                        size=(num_seqs, num_heads, num_partitions),
                        dtype=torch.float32,
                )
                max_logits = torch.empty_like(exp_sums)
                start = time.time()
                ops.paged_attention_v2(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        kv_scale,
                )
                torch.cuda.synchronize()  # kernel launches are asynchronous; wait before timing
                end = time.time()
                print(f"paged attention v2: {end - start}")
        else:
                raise AssertionError(f"Unknown version: {version}")


        return output





@torch.inference_mode()
def test_multi_query_kv_attention(
        num_seqs: int,
        num_heads: Tuple[int, int],
        head_size: int,
        dtype: torch.dtype,
        seed: int,
        device: str,
) -> None:
        random.seed(seed)
        torch.random.manual_seed(seed)
        if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
        torch.set_default_device(device)

        max_len = min(MAX_SEQ_LEN, 4096)
        seq_lens = random.sample(range(1, max_len), num_seqs)
        num_tokens = sum(seq_lens)
        scale = float(1.0 / (head_size**0.5))
        num_query_heads, num_kv_heads = num_heads
        qkv = torch.empty(num_tokens,
                          num_query_heads + 2 * num_kv_heads,
                          head_size,
                          dtype=dtype)
        qkv.uniform_(-scale, scale)
        query, key, value = qkv.split(
                [num_query_heads, num_kv_heads, num_kv_heads], dim=1)

        num_queries_per_kv = num_query_heads // num_kv_heads
        if num_queries_per_kv > 1:
                # Handle MQA and GQA
                key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
                value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)
        attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
        start=time.time()
        output = xops.memory_efficient_attention_forward(
                query.unsqueeze(0),
                key.unsqueeze(0),
                value.unsqueeze(0),
                attn_bias=attn_bias,
                p=0.0,
                scale=scale,
        )
        torch.cuda.synchronize()  # kernel launches are asynchronous; wait before timing
        end = time.time()
        print(f"flash attention: {end - start}")
torch.cuda.empty_cache()
res = test_multi_query_kv_attention(
        num_seqs=7,
        num_heads=(40, 40),
        head_size=64,
        dtype=torch.float16,
        seed=0,
        device="cuda:0",
)


output1 = test_paged_attention(
        version="v1",
        num_seqs=7,                  # number of sequences
        num_heads=(40, 40),
        head_size=64,
        use_alibi=False,
        block_size=16,
        dtype=torch.bfloat16,
        kv_cache_dtype="auto",
        seed=0,
        device="cuda:0",)
torch.cuda.empty_cache()
output2=test_paged_attention(
        version="v1",
        num_seqs=7,
        num_heads=(40, 40),
        head_size=64,
        use_alibi=False,
        block_size=16,
        dtype=torch.bfloat16,
        kv_cache_dtype="auto",
        seed=0,
        device="cuda:0",)
torch.cuda.empty_cache()
output2=test_paged_attention(
        version="v2",
        num_seqs=7,
        num_heads=(40, 40),
        head_size=64,
        use_alibi=False,
        block_size=16,
        dtype=torch.bfloat16,
        kv_cache_dtype="auto",
        seed=0,
        device="cuda:0",)
torch.cuda.empty_cache()
output2=test_paged_attention(
        version="v2",
        num_seqs=7,
        num_heads=(40, 40),
        head_size=64,
        use_alibi=False,
        block_size=16,
        dtype=torch.bfloat16,
        kv_cache_dtype="auto",
        seed=0,
        device="cuda:0",)
torch.cuda.empty_cache()

res = test_multi_query_kv_attention(
        num_seqs=7,
        num_heads=(40, 40),
        head_size=64,
        dtype=torch.float16,
        seed=0,
        device="cuda:0",
)
torch.cuda.empty_cache()

res = test_multi_query_kv_attention(
        num_seqs=7,
        num_heads=(40, 40),
        head_size=64,
        dtype=torch.float16,
        seed=0,
        device="cuda:0",
)
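Note that time.time() brackets asynchronous kernel launches, which is why torch.cuda.synchronize() is called before reading the clock above. An alternative is CUDA events, which record timestamps on the GPU itself; a minimal sketch (this times the whole test call including cache setup; move the events around the ops.paged_attention_* call to isolate the kernel):

start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)
start_evt.record()
test_paged_attention(version="v2", num_seqs=7, num_heads=(40, 40), head_size=64,
                     use_alibi=False, block_size=16, dtype=torch.bfloat16,
                     kv_cache_dtype="auto", seed=0, device="cuda:0")
end_evt.record()
torch.cuda.synchronize()
print(f"paged attention v2 (CUDA events): {start_evt.elapsed_time(end_evt):.3f} ms")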
  6. Analyzing paged attention with Nsight Compute
    test.py:
import random
from typing import List, Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.utils import get_max_shared_memory_bytes, is_hip
from vllm.utils import (create_kv_caches_with_random)

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
print(MAX_SEQ_LEN)        #41216
NUM_BLOCKS = 4321
PARTITION_SIZE = 512
seed=0
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
torch.set_default_device("cuda:0")
import time



def test_paged_attention(
        version: str,
        num_seqs: int,
        num_heads: Tuple[int, int],
        head_size: int,
        use_alibi: bool,
        block_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: str,
        seed: int,
        device: str,
) -> None:

        scale = float(1.0 / (head_size**0.5))
        num_query_heads, num_kv_heads = num_heads
        query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
        query.uniform_(-scale, scale)

        alibi_slopes = None

        print(MAX_SEQ_LEN)
        seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
        seq_lens[-1] = MAX_SEQ_LEN
        max_seq_len = max(seq_lens)
        seq_lens = torch.tensor(seq_lens, dtype=torch.int)

        # Create the block tables.
        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
        block_tables_lst: List[List[int]] = []
        for _ in range(num_seqs):
                block_table = [
                        random.randint(0, NUM_BLOCKS - 1)
                        for _ in range(max_num_blocks_per_seq)
                ]
                block_tables_lst.append(block_table)

        block_tables = torch.tensor(block_tables_lst, dtype=torch.int)

        # Create the KV caches.
        key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, block_size, 1,
                                                                num_kv_heads, head_size,
                                                                kv_cache_dtype, dtype, seed,
                                                                device)
        key_cache, value_cache = key_caches[0], value_caches[0]
        # Using default kv_scale
        kv_scale = 1.0

        output = torch.empty_like(query)
        if version == "v1":
                start=time.time()
                ops.paged_attention_v1(
                        output,
                        query,           #[7,40,64]
                        key_cache,       #[4321, 40, 8, 16, 8]
                        value_cache,     #[4321, 40, 64, 16]
                        num_kv_heads,    #40
                        scale,           #0.125
                        block_tables,    # [7, 2576] with MAX_SEQ_LEN = 41216
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,     #None
                        kv_cache_dtype,
                        kv_scale,
                )
                torch.cuda.synchronize()  # kernel launches are asynchronous; wait before timing
                end = time.time()
                print(f"paged attention v1: {end - start}")
        elif version == "v2":
                num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
                assert PARTITION_SIZE % block_size == 0
                num_seqs, num_heads, head_size = output.shape
                tmp_output = torch.empty(
                        size=(num_seqs, num_heads, num_partitions, head_size),
                        dtype=output.dtype,
                )
                exp_sums = torch.empty(
                        size=(num_seqs, num_heads, num_partitions),
                        dtype=torch.float32,
                )
                max_logits = torch.empty_like(exp_sums)
                start = time.time()
                ops.paged_attention_v2(
                        output,
                        exp_sums,
                        max_logits,
                        tmp_output,
                        query,
                        key_cache,
                        value_cache,
                        num_kv_heads,
                        scale,
                        block_tables,
                        seq_lens,
                        block_size,
                        max_seq_len,
                        alibi_slopes,
                        kv_cache_dtype,
                        kv_scale,
                )
                torch.cuda.synchronize()  # kernel launches are asynchronous; wait before timing
                end = time.time()
                print(f"paged attention v2: {end - start}")
        else:
                raise AssertionError(f"Unknown version: {version}")
        print(torch.cuda.max_memory_allocated("cuda:0") / 1024 ** 2)  # 192M
        print("=============================")
        print(torch.cuda.memory_allocated("cuda:0") / 1024 ** 2)  # 128M
        return output
output1 = test_paged_attention(
        version="v1",
        num_seqs=7,                  # number of sequences
        num_heads=(40, 40),
        head_size=64,
        use_alibi=False,
        block_size=16,
        dtype=torch.bfloat16,
        kv_cache_dtype="auto",
        seed=0,
        device="cuda:0",)

torch.cuda.empty_cache()

Test: >> /home/vllm-main/tests/kernels# ncu --set full -o paged_atten python test.py
Load the report in the Nsight Compute client for analysis:
1. Roofline analysis
(figures: Nsight Compute roofline charts for paged_attention_v1 and paged_attention_v2)
Analysis conclusions:
1. paged attention v1: compute intensity 4.66 < memory intensity 15.59, so it is a memory-access-bound kernel.
2. paged attention v1: the L1, L2, and DRAM bandwidth figures are 6.43, 17.03, and 15.59, respectively.
3. paged attention v1 lies in the bandwidth-bound region of the roofline: the kernel's compute intensity is below the platform's, and its throughput is below the platform's.
4. paged attention v2: compute intensity 26.04 < 75.58, still a memory-access-bound kernel.
5. paged attention v2: the L1, L2, and DRAM bandwidth figures are 20.91, 76.81, and 75.58, respectively.
6. Compared with v1, v2 improves both compute intensity and throughput, and its execution time drops.
Optimization direction for paged attention v1: raise the compute intensity and the kernel's throughput.
Optimization direction for paged attention v2: there is still room for improvement.
To be continued.
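For reference, the classification above is just the roofline rule: a kernel whose arithmetic intensity (FLOPs per byte moved) lies below the device's machine balance (peak FLOP/s divided by peak bandwidth) is memory-bound. A minimal helper, with hypothetical peak numbers rather than measured values for any particular GPU:

def roofline(flops, bytes_moved, peak_flops, peak_bw):
        intensity = flops / bytes_moved          # achieved arithmetic intensity, FLOP/byte
        ridge = peak_flops / peak_bw             # machine balance of the device
        bound = "memory-bound" if intensity < ridge else "compute-bound"
        attainable = min(peak_flops, intensity * peak_bw)
        return intensity, ridge, bound, attainable

# e.g. 1e12 FLOPs over 200 GB moved, on a device with 19.5 TFLOP/s peak
# compute and 1.5 TB/s peak DRAM bandwidth (placeholder numbers)
print(roofline(1e12, 200e9, 19.5e12, 1.5e12))    # (5.0, 13.0, 'memory-bound', 7.5e12)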

  2. Bandwidth analysis
    (figures: Nsight Compute memory workload analysis charts)
    L1/TEX Hit Rate: L1 hit rate
    L2 Hit Rate: L2 hit rate
    L2 Compression Success Rate: L2 compression success rate
    Mem Busy: memory unit utilization
    Mem Pipes Busy: memory pipeline utilization
    L2 Compression Ratio: L2 compression ratio
    Memory chart: the memory chart visualizes the flow of data accesses and the hit rates along the way; it reads from left to right, starting from the kernel where the computation happens.

    1. Optimization direction 1: the memory access pattern from L1TEX to L2 is not optimal. The granularity of an L1TEX request to L2 is a 128-byte cache line, i.e. 4 consecutive 32-byte sectors per L2 request. However, this kernel only accesses an average of 1.5 of the 4 possible sectors per cache line. Check the Source Counters section for uncoalesced stores, and try to minimize how many cache lines each memory request needs to touch.
    2. Optimization direction 2: the access pattern for loads from device memory causes 55,330,588 sectors to be read from DRAM, which is 1.0x the 55,318,297 sectors that miss in L2. DRAM reads triggered by L2 read misses have a 64-byte granularity, i.e. the lower or upper half of an L2 cache line. Try to change the access pattern so that both sectors returned by each DRAM read request are used, to make better use of DRAM throughput. For strided memory reads, avoid strides of 64 bytes or larger, so that unused sectors are not moved from DRAM into L2 (a small illustration follows this list).
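A quick way to see the last point on whatever GPU is at hand: with a 64-byte stride only one float out of every 64-byte span is useful, so reducing a strided view moves far more DRAM sectors per useful element than reducing the same data packed contiguously (a rough illustration; absolute numbers depend on the device):

import time
import torch

x = torch.randn(1 << 24, device="cuda")    # 64 MiB of float32
strided = x[::16]                           # 64-byte stride: 1 useful float per 64 bytes
packed = strided.contiguous()               # same 1M elements, densely packed

def bench(t, iters=100):
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
                t.sum()
        torch.cuda.synchronize()
        return (time.time() - start) / iters

print(f"strided sum: {bench(strided) * 1e3:.3f} ms, contiguous sum: {bench(packed) * 1e3:.3f} ms")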
  3. Occupancy analysis
    (figure: Nsight Compute occupancy section for paged_attention_v1 and v2)
    1. Warp occupancy (usually just called occupancy) is a key metric of GPU utilization and parallelism. It is the ratio of active (resident) warps on a streaming multiprocessor (SM) to the maximum number of active warps the SM supports. A warp is the basic unit of GPU thread execution, typically 32 threads that execute the same instruction in parallel.
    2. Occupancy directly reflects hardware utilization. High occupancy usually means the GPU's parallel resources are being used effectively and more work can be kept in flight.
    3. High occupancy does not always mean high performance (memory bandwidth, cache efficiency, and other factors also matter), but low occupancy usually hurts performance, because it limits the GPU's ability to hide latency and leaves compute units idle while waiting for data.
    4. In short, occupancy is the ratio of active warps on an SM to the maximum possible number of active warps.

Factors that affect occupancy

  1. Register usage: the number of registers each thread uses limits how many warps can be resident at once. If each thread uses too many registers, fewer warps fit on the SM, lowering occupancy.
  2. Shared memory usage: the shared memory used per block also limits occupancy. Shared memory is a finite resource; if each block uses too much, fewer blocks can run concurrently, lowering occupancy.
  3. Block configuration: the block size (threads per block) determines how many warps each block brings and which per-SM limit is hit first; choosing it sensibly helps optimize occupancy for a given kernel.

How to improve occupancy

  1. Reduce register usage: optimize the code so each thread needs fewer registers, allowing more warps to be resident and raising occupancy.
  2. Configure blocks sensibly: pick a block size that suits the workload and the GPU's hardware limits.
  3. Optimize memory access: better data layout and access patterns reduce latency and conflicts, improving cache efficiency and bandwidth utilization, which lets the available occupancy pay off.
  4. Reading the chart above: paged attention v1 and v2 use the same number of registers and the same block configuration, but v1 uses more shared memory per block, which lowers the achievable occupancy on each SM (a rough estimator is sketched below).
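A back-of-the-envelope estimator of that effect (a sketch only: the per-SM limits are placeholders for a hypothetical GPU, and real hardware also applies register and shared-memory allocation granularities that this ignores):

def estimate_occupancy(threads_per_block, regs_per_thread, smem_per_block,
                       regs_per_sm=65536, smem_per_sm=101376,
                       max_threads_per_sm=2048, max_blocks_per_sm=32):
        limits = [
                regs_per_sm // max(1, regs_per_thread * threads_per_block),  # register limit
                smem_per_sm // max(1, smem_per_block),                       # shared-memory limit
                max_threads_per_sm // threads_per_block,                     # thread limit
                max_blocks_per_sm,                                           # block limit
        ]
        active_blocks = min(limits)
        active_warps = active_blocks * threads_per_block // 32
        max_warps = max_threads_per_sm // 32
        return active_warps / max_warps

# same registers and block size, more shared memory per block -> lower occupancy
print(estimate_occupancy(128, 64, 48 * 1024))   # 0.125
print(estimate_occupancy(128, 64, 16 * 1024))   # 0.375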

  7. Analyzing paged attention with the PyTorch profiler

with torch.profiler.profile(on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/example'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True):
        output1 = test_paged_attention(
                version="v1",
                num_seqs=7,                  # number of sequences
                num_heads=(40, 40),
                head_size=64,
                use_alibi=False,
                block_size=16,
                dtype=torch.bfloat16,
                kv_cache_dtype="auto",
                seed=0,
                device="cuda:0",)
torch.cuda.empty_cache()

Load the trace in TensorBoard: (base) D:\cnki_1\model_predict>tensorboard --logdir=D:/log
(figure: TensorBoard view of the profiler trace)
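Besides the TensorBoard trace, the profiler object can print a summary table directly; a minimal variant of the block above (same test_paged_attention call as before, using torch.profiler's key_averages):

with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True) as prof:
        test_paged_attention(version="v1", num_seqs=7, num_heads=(40, 40),
                             head_size=64, use_alibi=False, block_size=16,
                             dtype=torch.bfloat16, kv_cache_dtype="auto",
                             seed=0, device="cuda:0")

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))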
