MixtralForCausalLM DeepSpeed Inference节约HOST内存【最新的方案】

本文演示了使用DeepSpeed Inference运行MixtralForCausalLM时如何节约HOST内存。
方法:每个rank分别保存自己切分(TP)后的权值;推理时使用accelerate的init_empty_weights创建空模型,再由每个rank只加载自己的那一份权值。
增加的功能:

  • safetensors分块的存储与加载
  • 解决通过register_buffer(..., persistent=False)注册的buffer在空权重初始化后数值丢失的问题

一.效果

| 运行方式 | HOST内存占用 | 备注 |
| --- | --- | --- |
| 单卡推理 | 13198 MB | |
| DS 4TP | 13246 MB/GPU | |
| DS 4TP 优化内存占用后 | 369 MB/GPU | 模型参数直接在设备上分配,每个rank只加载自己的分片,更节约HOST内存 |

二.特别说明

  • 1.MixtralRotaryEmbedding中使用了self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    因为persistent为False,所以该buffer不会保存到state_dict中,module.to_empty(device)也不会保留它的值。
    因此只能在模型初始化之后先把它保存出来,等engine.module加载完权值之后再把这个buffer替换进去。

三.测试步骤

1.创建Mixtral-8x7B配置文件(简化了)

# Create a working directory and write a scaled-down Mixtral config
# (hidden_size 1024, intermediate_size 4096) that every script below reads.
mkdir skip_init_demo
cd skip_init_demo
tee ./config.json <<-'EOF'
{
  "architectures": [
    "MixtralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "max_position_embeddings": 1024,
  "model_type": "mixtral",
  "num_attention_heads": 32,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "num_local_experts": 8,
  "output_router_logits": false,
  "rms_norm_eps": 1e-05,
  "rope_theta": 1000000.0,
  "router_aux_loss_coef": 0.02,
  "sliding_window": 128,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.36.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}
EOF

2.生成随机模型,运行cpu float32推理,输出结果

# Remove any previously generated checkpoint, then write the generation script.
rm -rf Mixtral-8x7B
tee gen_model.py <<-'EOF'
import torch
import os
import time
def main():
    """Create a randomly initialised fp16 Mixtral model, save it, print reference logits.

    The checkpoint is written to ./Mixtral-8x7B in safetensors format. The
    printed values are the first 8 logits of a float32 CPU forward pass and
    serve as the reference that the loading scripts are compared against.
    """
    # Seed 1 fixes the random weight initialisation.
    torch.manual_seed(1)
    from transformers import MixtralForCausalLM, MixtralConfig
    config = MixtralConfig.from_pretrained("./config.json")
    model = MixtralForCausalLM(config).half()
    model.eval()
    model.save_pretrained("./Mixtral-8x7B", safe_serialization=True)

    # Seed 2 fixes the input tokens; the loading scripts use the same seed and
    # shape, so their printed logits are directly comparable.
    torch.manual_seed(2)
    input_tokens = torch.randint(0, 32000, (1, 128))
    # Upcast the (fp16-rounded) weights to float32 for a higher-precision reference.
    model = model.float()
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        output = model(input_tokens)
    output = output.logits.reshape(-1).cpu().numpy()[:8]
    print(output)


if __name__ == "__main__":
    main()
EOF
# Generate and save the random checkpoint, print the reference logits, show its size.
python gen_model.py
du Mixtral-8x7B -lh

输出

6.3G    Mixtral-8x7B

[-0.9623295  -0.36580455  0.767425    1.7021806  -0.17950581  0.36059803
 -0.49157432 -0.58618194]

3.加载模型,cuda 单卡推理

# Write the single-GPU loading script (baseline HOST-memory measurement).
tee open_model.py <<-'EOF'
import torch
import os
import psutil
import time
from transformers.modeling_utils import load_sharded_checkpoint,load_state_dict
import json
from safetensors import safe_open

def get_mem_info():
    """Print this process's resident (RSS) and virtual (VMS) memory in MB."""
    mib = 1024 * 1024
    usage = psutil.Process(os.getpid()).memory_info()
    rss_mb = usage.rss / mib
    vms_mb = usage.vms / mib
    print(f"RSS: {rss_mb:.2f}MB VMS:{vms_mb:.2f}MB")

def _read_sharded_safetensors(model_dir, index_name="model.safetensors.index.json"):
    """Read every tensor of a save_pretrained safetensors checkpoint into one dict.

    The index's ``weight_map`` maps tensor name -> shard file name, so many
    names share one file; each shard file is opened and read exactly once.
    If no index file exists, a single ``model.safetensors`` is assumed
    (how save_pretrained writes an unsharded checkpoint).
    """
    index_path = os.path.join(model_dir, index_name)
    if os.path.exists(index_path):
        with open(index_path, "r") as file:
            index_data = json.load(file)
        shard_files = sorted(set(index_data.get("weight_map", {}).values()))
    else:
        shard_files = ["model.safetensors"]
    state_dict = {}
    for shard_file in shard_files:
        with safe_open(os.path.join(model_dir, shard_file), framework="pt") as f:
            for name in f.keys():
                state_dict[name] = f.get_tensor(name)
    return state_dict


def main():
    """Single-GPU baseline: build the full model on the host, load it, run on cuda:0.

    Memory is printed after process start, after model creation (randomly
    initialised on the host) and after the checkpoint has been loaded.
    """
    from transformers import MixtralForCausalLM, MixtralConfig
    get_mem_info()
    config = MixtralConfig.from_pretrained("./config.json")
    model = MixtralForCausalLM(config).half()
    get_mem_info()

    state_dict = _read_sharded_safetensors("Mixtral-8x7B")
    model.load_state_dict(state_dict, strict=True)
    get_mem_info()
    # The weights now live in `model`; release the host-side copy.
    del state_dict

    model = model.to("cuda:0")
    model.eval()
    # Same seed/shape as gen_model.py so the logits are comparable.
    torch.manual_seed(2)
    input_tokens = torch.randint(0, 32000, (1, 128)).to("cuda:0")
    with torch.no_grad():
        output = model(input_tokens)
    output = output.logits.reshape(-1).cpu().numpy()[:8]
    print(output)


if __name__ == "__main__":
    main()
EOF
# Run the single-GPU baseline.
python open_model.py

输出:

RSS: 251.70MB VMS:3292.21MB
RSS: 6697.91MB VMS:13695.17MB
RSS: 13198.57MB VMS:26385.02MB

[-0.9633789  -0.36450195  0.76708984  1.703125   -0.1772461   0.3581543
 -0.48901367 -0.5888672 ]

4.DS 4 TP cuda 推理

# Overwrite open_model.py with the DeepSpeed tensor-parallel baseline.
tee open_model.py <<-'EOF'
import torch
import os
import psutil
import time
from transformers.modeling_utils import load_sharded_checkpoint,load_state_dict
import deepspeed
from deepspeed.accelerator import get_accelerator
import json
from safetensors import safe_open

# Initialise torch.distributed with the NCCL backend for the processes started by the launcher.
deepspeed.init_distributed(dist_backend='nccl')
# Total number of ranks; used below as the tensor-parallel degree.
world_size = torch.distributed.get_world_size()
# LOCAL_RANK is expected to be set in the environment by the deepspeed launcher.
local_rank=int(os.environ['LOCAL_RANK'])
# Global rank; only rank 0 prints the logits.
rank=torch.distributed.get_rank()

def get_mem_info(prefix):
    """Print this rank's resident (RSS) and virtual (VMS) memory in MB, tagged with *prefix*."""
    mib = 1024 * 1024
    usage = psutil.Process(os.getpid()).memory_info()
    rank_tag = os.environ['LOCAL_RANK']
    print(f"{prefix} RANK:{rank_tag} RSS: {usage.rss / mib:.2f}MB VMS:{usage.vms / mib:.2f}MB")

def _read_sharded_safetensors(model_dir, index_name="model.safetensors.index.json"):
    """Read every tensor of a save_pretrained safetensors checkpoint into one dict.

    The index's ``weight_map`` maps tensor name -> shard file name, so many
    names share one file; each shard file is opened and read exactly once.
    If no index file exists, a single ``model.safetensors`` is assumed
    (how save_pretrained writes an unsharded checkpoint).
    """
    index_path = os.path.join(model_dir, index_name)
    if os.path.exists(index_path):
        with open(index_path, "r") as file:
            index_data = json.load(file)
        shard_files = sorted(set(index_data.get("weight_map", {}).values()))
    else:
        shard_files = ["model.safetensors"]
    state_dict = {}
    for shard_file in shard_files:
        with safe_open(os.path.join(model_dir, shard_file), framework="pt") as f:
            for name in f.keys():
                state_dict[name] = f.get_tensor(name)
    return state_dict


def main():
    """DeepSpeed tensor-parallel baseline with the full model on every rank's host.

    Each rank builds the whole fp16 model on the host, loads the full
    checkpoint, then lets ``deepspeed.init_inference`` shard it over
    ``world_size`` ranks. Rank 0 prints the first 8 logits.
    """
    torch.set_num_threads(1)
    from transformers import MixtralForCausalLM, MixtralConfig
    get_mem_info("Init")
    config = MixtralConfig.from_pretrained("./config.json")
    model = MixtralForCausalLM(config).half()
    get_mem_info("ModelCreate")
    print("-----------------------")

    state_dict = _read_sharded_safetensors("Mixtral-8x7B")
    model.load_state_dict(state_dict, strict=True)
    get_mem_info("LoadState")
    print("-----------------------")
    # The weights now live in `model`; release the host-side copy before sharding.
    del state_dict
    model.eval()
    engine = deepspeed.init_inference(model,
                                      tensor_parallel={"tp_size": world_size},
                                      dtype=torch.float16,
                                      replace_with_kernel_inject=False)
    device = get_accelerator().current_device_name()
    print("device:", device)
    # Same seed/shape as gen_model.py so the logits are comparable.
    torch.manual_seed(2)
    input_tokens = torch.randint(0, 32000, (1, 128)).to(device)
    with torch.no_grad():
        output = engine(input_tokens)
    output = output.logits.reshape(-1).cpu().numpy()[:8]
    if rank == 0:
        print(output)


if __name__ == "__main__":
    main()
EOF
# Launch 4 processes (one per GPU) with the deepspeed launcher.
deepspeed --num_gpus=4 open_model.py

输出:


Init RANK:1 RSS: 270.02MB VMS:3414.44MB
Init RANK:3 RSS: 270.43MB VMS:3414.45MB
Init RANK:2 RSS: 270.22MB VMS:3414.45MB
Init RANK:0 RSS: 270.38MB VMS:3486.45MB

ModelCreate RANK:0 RSS: 6757.33MB VMS:9965.12MB
ModelCreate RANK:3 RSS: 6727.30MB VMS:9862.06MB
ModelCreate RANK:2 RSS: 6757.18MB VMS:9893.12MB
ModelCreate RANK:1 RSS: 6756.99MB VMS:9893.12MB

LoadState RANK:2 RSS: 13248.96MB VMS:22772.97MB
LoadState RANK:0 RSS: 13245.91MB VMS:22616.97MB
LoadState RANK:3 RSS: 13233.00MB VMS:22490.91MB
LoadState RANK:1 RSS: 13246.22MB VMS:23240.97MB

[-0.96240234 -0.36547852  0.7680664   1.703125   -0.17382812  0.359375
 -0.49169922 -0.5883789 ]

5.分别保存DS 4TP每个rank上engine.module的权值

# Overwrite open_model.py with a script that saves each rank's tensor-parallel shard.
tee open_model.py <<-'EOF'
import torch
import os
import psutil
import time
from transformers.modeling_utils import load_sharded_checkpoint,load_state_dict
import deepspeed
from deepspeed.accelerator import get_accelerator
import json
from safetensors import safe_open
from safetensors.torch import save_file, load_file

# Initialise torch.distributed with the NCCL backend for the processes started by the launcher.
deepspeed.init_distributed(dist_backend='nccl')
# Total number of ranks; used below as the tensor-parallel degree.
world_size = torch.distributed.get_world_size()
# LOCAL_RANK is expected to be set in the environment by the deepspeed launcher;
# it names this rank's output directory below.
local_rank=int(os.environ['LOCAL_RANK'])
rank=torch.distributed.get_rank()

def get_mem_info(prefix):
    """Print this rank's resident (RSS) and virtual (VMS) memory in MB, tagged with *prefix*."""
    mib = 1024 * 1024
    usage = psutil.Process(os.getpid()).memory_info()
    rank_tag = os.environ['LOCAL_RANK']
    print(f"{prefix} RANK:{rank_tag} RSS: {usage.rss / mib:.2f}MB VMS:{usage.vms / mib:.2f}MB")

def save_state_dict(state_dict, save_dir, max_bytes_per_file=1 * 1024 * 1024 * 1024):
    """Write *state_dict* to *save_dir* as size-capped safetensors shards plus an index.

    Tensors are packed, in iteration order, into shards of at most
    ``max_bytes_per_file`` bytes (default 1 GiB). A single tensor larger than
    the cap gets a shard of its own. Shards are named
    ``model_part_{i}.safetensors`` and listed, in order, in ``index.json``::

        {"metadata": {"total_parts": N}, "weight_map": ["model_part_0.safetensors", ...]}

    Args:
        state_dict: mapping of name -> tensor (here ``engine.module.state_dict()``).
        save_dir: output directory; created if it does not exist.
        max_bytes_per_file: size cap per shard, in bytes.
    """
    # Group tensors into shards by byte size.
    split_state_dicts = []
    current_state_dict = {}
    current_size = 0
    for param_name, param_tensor in state_dict.items():
        tensor_size = param_tensor.element_size() * param_tensor.nelement()
        # Close the current shard if this tensor would overflow it. The
        # emptiness check prevents writing an empty shard when the very first
        # tensor is already larger than the cap.
        if current_state_dict and current_size + tensor_size > max_bytes_per_file:
            split_state_dicts.append(current_state_dict)
            current_state_dict = {}
            current_size = 0
        current_state_dict[param_name] = param_tensor
        current_size += tensor_size

    # Keep the last, partially filled shard.
    if current_state_dict:
        split_state_dicts.append(current_state_dict)

    # Write the shards and build the index.
    os.makedirs(save_dir, exist_ok=True)
    index = {
        "metadata": {
            "total_parts": len(split_state_dicts)
        },
        "weight_map": []
    }
    for i, sd in enumerate(split_state_dicts):
        part_name = f"model_part_{i}.safetensors"
        # safetensors refuses non-contiguous tensors (e.g. sliced views); make
        # them contiguous one shard at a time so the extra memory stays bounded.
        save_file({k: v.contiguous() for k, v in sd.items()}, os.path.join(save_dir, part_name))
        index["weight_map"].append(part_name)

    # Write the index file.
    index_file = os.path.join(save_dir, "index.json")
    with open(index_file, 'w') as f:
        json.dump(index, f, indent=4)

def _read_sharded_safetensors(model_dir, index_name="model.safetensors.index.json"):
    """Read every tensor of a save_pretrained safetensors checkpoint into one dict.

    The index's ``weight_map`` maps tensor name -> shard file name, so many
    names share one file; each shard file is opened and read exactly once.
    If no index file exists, a single ``model.safetensors`` is assumed
    (how save_pretrained writes an unsharded checkpoint).
    """
    index_path = os.path.join(model_dir, index_name)
    if os.path.exists(index_path):
        with open(index_path, "r") as file:
            index_data = json.load(file)
        shard_files = sorted(set(index_data.get("weight_map", {}).values()))
    else:
        shard_files = ["model.safetensors"]
    state_dict = {}
    for shard_file in shard_files:
        with safe_open(os.path.join(model_dir, shard_file), framework="pt") as f:
            for name in f.keys():
                state_dict[name] = f.get_tensor(name)
    return state_dict


def main():
    """Shard the full checkpoint with DeepSpeed TP and save each rank's part.

    Every rank loads the complete model on the host, lets ``init_inference``
    split it over ``world_size`` ranks, and writes ``engine.module``'s
    state_dict to ``./Mixtral-8x7B-{local_rank}`` via ``save_state_dict``.
    """
    from transformers import MixtralForCausalLM, MixtralConfig
    get_mem_info("Init")
    config = MixtralConfig.from_pretrained("./config.json")
    model = MixtralForCausalLM(config).half()
    get_mem_info("ModelCreate")
    print("-----------------------")
    state_dict = _read_sharded_safetensors("Mixtral-8x7B")
    model.load_state_dict(state_dict, strict=True)
    get_mem_info("LoadState")
    print("-----------------------")
    # The weights now live in `model`; release the host-side copy before sharding.
    del state_dict
    engine = deepspeed.init_inference(model,
                                      tensor_parallel={"tp_size": world_size},
                                      dtype=torch.float16,
                                      replace_with_kernel_inject=False)
    # NOTE(review): the directory is keyed by local_rank, which is only unique on
    # a single node; a multi-node run would need the global rank here (and in
    # the loading script).
    save_state_dict(engine.module.state_dict(), f"./Mixtral-8x7B-{local_rank}")


if __name__ == "__main__":
    main()
EOF
# Save each of the 4 ranks' shards, then show the per-rank checkpoint sizes.
deepspeed --num_gpus=4 open_model.py
du Mixtral-8x7B-* -lh

输出

1.7G    Mixtral-8x7B-0
1.7G    Mixtral-8x7B-1
1.7G    Mixtral-8x7B-2
1.7G    Mixtral-8x7B-3

6.DS 4TP推理,init_empty_weights初始化模型,每个rank加载自己engine.module的权值

# Overwrite open_model.py with the low-HOST-memory version that loads per-rank shards.
tee open_model.py <<-'EOF'
import torch
import os
import psutil
import time
from accelerate import init_empty_weights
from transformers.modeling_utils import load_sharded_checkpoint,load_state_dict
import deepspeed
from deepspeed.accelerator import get_accelerator
import json
from safetensors import safe_open
from safetensors.torch import save_file, load_file

# Initialise torch.distributed with the NCCL backend for the processes started by the launcher.
deepspeed.init_distributed(dist_backend='nccl')
# Total number of ranks; used below as the tensor-parallel degree.
world_size = torch.distributed.get_world_size()
# LOCAL_RANK is expected to be set in the environment by the deepspeed launcher;
# it selects which per-rank checkpoint directory this process loads.
local_rank=int(os.environ['LOCAL_RANK'])
# Global rank; only rank 0 prints the logits.
rank=torch.distributed.get_rank()

def get_mem_info(prefix):
    """Print this rank's resident (RSS) and virtual (VMS) memory in MB, tagged with *prefix*."""
    mib = 1024 * 1024
    usage = psutil.Process(os.getpid()).memory_info()
    rank_tag = os.environ['LOCAL_RANK']
    print(f"{prefix} RANK:{rank_tag} RSS: {usage.rss / mib:.2f}MB VMS:{usage.vms / mib:.2f}MB")

def my_load_state_dict(model, save_dir, device="cpu"):
    """Load a per-rank checkpoint directory (``index.json`` + shards) into *model*.

    ``index.json``'s ``weight_map`` is a list of shard file names. Shards are
    loaded one at a time, so at most one shard's tensors are held outside the
    model at once. Key checking matches ``load_state_dict(strict=True)`` over
    the union of all shards: missing or unexpected keys raise.

    Args:
        model: module to load into (here ``engine.module`` after TP sharding).
        save_dir: directory containing ``index.json`` and the shard files.
        device: device ``safe_open`` loads tensors onto; ``"cpu"`` (default)
            matches the original behaviour.

    Raises:
        RuntimeError: on missing/unexpected keys or a shape mismatch.
    """
    index_file = os.path.join(save_dir, "index.json")
    with open(index_file, "r") as file:
        index_data = json.load(file)

    # Keys a strict load must cover: exactly what model.state_dict() exposes
    # (non-persistent buffers are not part of it).
    expected_keys = set(model.state_dict().keys())
    loaded_keys = set()
    for shard_name in index_data.get('weight_map', []):
        weights_path = os.path.join(save_dir, shard_name)
        shard = {}
        with safe_open(weights_path, framework="pt", device=device) as f:
            for key in f.keys():
                shard[key] = f.get_tensor(key)
        # Each shard holds only some keys, so load non-strictly here and enforce
        # strictness over all shards below; shape mismatches still raise here.
        incompatible = model.load_state_dict(shard, strict=False)
        if incompatible.unexpected_keys:
            raise RuntimeError(
                f"Unexpected key(s) in {weights_path}: {incompatible.unexpected_keys}"
            )
        loaded_keys.update(shard.keys())

    missing_keys = sorted(expected_keys - loaded_keys)
    if missing_keys:
        raise RuntimeError(
            f"Missing key(s) in state_dict loaded from {save_dir}: {missing_keys}"
        )

def main():
    """Low-HOST-memory DS TP inference from per-rank shards.

    The model skeleton is built under ``init_empty_weights``, which creates
    parameters on the meta device. Buffers stay real because its default is
    ``include_buffers=False``. The model is sharded with
    ``deepspeed.init_inference``. This rank's weights are loaded from
    ``./Mixtral-8x7B-{local_rank}``. The buffers are restored and one forward
    pass is run; rank 0 prints the first 8 logits.
    """
    from transformers import MixtralForCausalLM, MixtralConfig
    get_mem_info("Init")
    config = MixtralConfig.from_pretrained("./config.json")
    with init_empty_weights():
        model = MixtralForCausalLM(config).half()
    get_mem_info("ModelCreate")
    print("-----------------------")
    # Snapshot the buffers (e.g. rotary caches registered with persistent=False).
    # They are not in the saved shards, and module.to_empty() does not keep their
    # values. Clone them so the snapshot cannot be affected by whatever
    # init_inference does to the module's own tensors.
    buffer_dict = {name: buf.detach().clone() for name, buf in model.named_buffers()}

    engine = deepspeed.init_inference(model,
                                      tensor_parallel={"tp_size": world_size},
                                      dtype=torch.float16,
                                      replace_with_kernel_inject=False)
    my_load_state_dict(engine.module, f"./Mixtral-8x7B-{local_rank}")

    # Write the snapshotted buffer values back into the sharded module.
    with torch.no_grad():
        for name, buf in engine.module.named_buffers():
            buf.copy_(buffer_dict[name])
    engine.module.eval()

    get_mem_info("LoadState")
    device = get_accelerator().current_device_name()
    # Same seed/shape as gen_model.py so the logits are comparable.
    torch.manual_seed(2)
    input_tokens = torch.randint(0, 32000, (1, 128)).to(device)
    with torch.no_grad():
        output = engine(input_tokens)
    output = output.logits.reshape(-1).cpu().numpy()[:8]
    if rank == 0:
        print(output)


if __name__ == "__main__":
    main()
EOF
# Launch the low-HOST-memory version on 4 GPUs.
deepspeed --num_gpus=4 open_model.py

输出


Init RANK:1 RSS: 269.73MB VMS:3382.40MB
Init RANK:2 RSS: 269.45MB VMS:3382.39MB
Init RANK:3 RSS: 269.86MB VMS:3382.39MB
Init RANK:0 RSS: 269.96MB VMS:3454.39MB

ModelCreate RANK:1 RSS: 300.44MB VMS:17064.71MB
ModelCreate RANK:0 RSS: 297.03MB VMS:17136.70MB
ModelCreate RANK:2 RSS: 299.22MB VMS:17064.70MB
ModelCreate RANK:3 RSS: 300.66MB VMS:17065.70MB

LoadState RANK:0 RSS: 366.28MB VMS:20159.03MB
LoadState RANK:3 RSS: 369.87MB VMS:20152.03MB
LoadState RANK:2 RSS: 368.37MB VMS:20151.02MB
LoadState RANK:1 RSS: 369.16MB VMS:20087.04MB

[-0.96240234 -0.36547852  0.7680664   1.703125   -0.17382812  0.359375
 -0.49169922 -0.5883789 ]

  • 12
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Hi20240217

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值