Load LoRA XXX.safetensors to diffusers (StableDiffusionPipeline/StableDiffusionControlNetPipeline)

该函数从安全张量中加载LoRA权重,然后更新Diffusers模型中的text_encoder和unet层。它遍历state_dict,根据键的结构找到对应模型层,并应用权重更新。权重更新涉及到weight_up、weight_down和alpha,最后将这些权重适当地加到当前层的权重数据上。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype):
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    # load LoRA weight from .safetensors
    state_dict = load_file(checkpoint_path, device=device)

    updates = defaultdict(dict)
    for key, value in state_dict.items():
        # it is suggested to print out the key, it usually will be something like below
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    # directly update weight in diffusers model
    for layer, elems in updates.items():

        if "text" in layer:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        # get elements for this layer
        weight_up = elems['lora_up.weight'].to(dtype)
        weight_down = elems['lora_down.weight'].to(dtype)
        alpha = elems['alpha']
        if alpha:
            alpha = alpha.item() / weight_up.shape[1]
        else:
            alpha = 1.0

        curr_layer.weight.data = curr_layer.weight.data.to(device) 
        # update weight
        if len(weight_up.shape) == 4:
            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
        else:
            curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)

    return pipeline

Example usage:

pipe = load_lora_weights(pipe, lora_path, 1.0, 'cuda', torch.float16)

### 如何用 Python 加载或操作 LoRA F_1 safetensors 格式的文件 `safetensors` 是一种高效且安全的模型存储格式,由 Hugging Face 推出,专注于提升模型的安全性和加载性能[^1]。为了加载 `.safetensors` 文件中的数据,可以使用 `safetensors` 库提供的 API。 以下是具体方法: #### 安装依赖库 首先需要安装 `safetensors` 库,可以通过以下命令完成: ```bash pip install safetensors ``` #### 使用 Python 加载 `.safetensors` 文件 下面是一个完整的代码示例,展示如何加载并操作名为 `LoRA_F_1.safetensors` 的文件: ```python import torch from safetensors.torch import load_file, save_file # 加载 .safetensors 文件 file_path = "LoRA_F_1.safetensors" weights = load_file(file_path) # 查看权重键名 print(list(weights.keys())) # 如果需要提取特定层的权重 if 'layer_name' in weights: layer_weight = weights['layer_name'] print(f"Layer weight shape: {layer_weight.shape}") # 修改权重(如果需要) modified_weights = {} for key, value in weights.items(): modified_value = value * 2 # 假设我们想将所有权重乘以2作为修改 modified_weights[key] = modified_value # 将修改后的权重保存回新的 .safetensors 文件 output_file_path = "Modified_LoRA_F_1.safetensors" save_file(modified_weights, output_file_path) ``` 上述代码实现了以下几个功能: 1. **加载**:通过调用 `load_file()` 函数读取 `.safetensors` 文件的内容,并将其解析为字典形式的数据结构,其中键是权重名称,值是对应的张量。 2. **查看权重信息**:打印出权重的键名列表以便了解模型内部结构。 3. **访问特定权重**:可以根据键名获取某一层的具体权重。 4. **修改权重**:对原始权重进行简单变换(如缩放),并将结果存入新字典中。 5. **保存**:利用 `save_file()` 函数将修改后的权重重新保存到一个新的 `.safetensors` 文件中。 需要注意的是,`.safetensors` 文件只包含模型的权重参数而无任何执行逻辑,因此在实际应用时可能还需要结合其他框架(例如 PyTorch 或 TensorFlow)来构建完整的推理流程。 ---
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值