图像生成:SD lora加载代码详解与实现

本文详细介绍了在SDWebUI中加载和融合Lora模型的过程,包括模型架构加载、safetensors权重处理、名称匹配以及权重融合的具体步骤。
摘要由CSDN通过智能技术生成


前言

SD中lora的加载相信都不陌生,但是大家大多数都是利用SD webUI加载lora,本文主要梳理一下SD webUI中lora加载的代码逻辑。关于lora的原理,可以参考我之前的博客——图像生成:SD LoRA模型详解


一、SD模型介绍

SD model结构一般分为几个部分,如下:
在这里插入图片描述

SD webui使用pytorch lightning搭建,了解pl的同学可能知道,模型的相关配置一般都写在yaml文件中,因此其实可以根据yaml文件来判断模型的基本结构,类似如下:

      unet_config:
        target: ldm.modules.diffusionmodules.openaimodel.UNetModel
        params:
          image_size: 32 # unused
          in_channels: 4
          out_channels: 4
          model_channels: 320
          attention_resolutions: [ 4, 2, 1 ]
          num_res_blocks: 2
          channel_mult: [ 1, 2, 4, 4 ]
          num_heads: 8
          use_spatial_transformer: True
          transformer_depth: 1
          context_dim: 768
          use_checkpoint: True
          legacy: False
  
      first_stage_config:
        target: ldm.models.autoencoder.AutoencoderKL
        params:
          embed_dim: 4
          monitor: val/rec_loss
          ddconfig:
            double_z: true
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult:
            - 1
            - 2
            - 4
            - 4
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
          lossconfig:
            target: torch.nn.Identity
  
      cond_stage_config:
        target: ldm.modules.encoders.modules.FrozenCLIPEmbedder

二、模型加载

因为整个模型分为VAE,CLIP,Unet等多个部分,lora模型一般是对这几个部分进行权重修改,有的可能只修改了Unet,有的可能多个部分都修改了,这个是由当时lora的训练师冻结的模块决定的。本文接下来主要以Unet部分的加载和lora修改为例来介绍,其它模块类似。

1. 模型架构加载

首先根据yaml文件加载模型架构(这里我只保留了Unet的配置文件,然后进行加载)

unet_config_path = '/data/wangyx/工程处理/sd/unet.yaml'
diff_model_config = OmegaConf.load(unet_config_path)  
unet_config = diff_model_config.model.unet_config
diffusion_model = instantiate_from_config(unet_config)  

2. safetensors权重加载

这里以比较火的chilloutmix作为example 。( ⁼̴̀ .̫ ⁼̴ )✧

#可视化权重结构
import torch
from safetensors.torch import load_file, save_file
from safetensors import safe_open
model_path = '/stable-diffusion-webui/stable-diffusion-webui/models/Stable-diffusion/chilloutmix_NiPrunedFp32Fix.safetensors'
tensors = {}
with safe_open(model_path, framework="pt", device='cpu') as f:
    for k in f.keys():
        tensors[k] = f.get_tensor(k)

获取权重后即可把对应权重文件加载到模型结构中。

3. lora权重加载

lora的权重加载直接用safetensor加载即可

lora_path = '**/sdxl_lcm_lora.safetensors'
pl_sd = safetensors.torch.load_file(lora_path) 

三、Name匹配

打印出lora权重名字和模型每一层的名字后其实可以发现,其实都是不对应的,因此需要手动将他们匹配起来,在SD webUI中的Lora中部分使用下面这样一个函数完成权重匹配

def convert_diffusers_name_to_compvis(key, is_sd2):
    def match(match_list, regex_text):
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = re.match(regex, key)
        if not r:
            return False

        match_list.clear()
        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key


key1 = 'lora_unet_down_blocks_0_downsamplers_0_conv'  #.alpha
new_key1 = convert_diffusers_name_to_compvis(key1, is_sd2=True)
print(new_key1)

通过这样一个函数即可把lora权重的名字替换成和模型名字一样的。


另外同时对SD原模型的层名字进行修改,并存在一个字典中:

def assign_network_names_to_compvis_modules(sd_model):
    network_layer_mapping = {}
    for name, module in sd_model.named_modules():
        network_name = name.replace(".", "_")
        network_layer_mapping[network_name] = module
        module.network_layer_name = network_name
    sd_model.network_layer_mapping = network_layer_mapping

这样子lora和模型的名字就完成对应了。

四、权重融合

完成权重匹配后,就可以进行权重融合了,这里我将SD wenUI中的代码摘了一部分出来进行实现,从而更好理解原理。

1、构建net类

这里我调用了webUI中的net类,用于后续赋予相关属性

network_on_disk = NetworkOnDisk('name', '.pth')
net = Network('name', network_on_disk)
#这部分是建立一个空壳,用于后续操作

2、匹配lora weight和model weight

#创建nametuple保存权重
NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])
matched_networks = {}

for key_network, weight in pl_sd.items():  #循环lora的每一项
    key_network_without_network_parts, network_part = key_network.split(".", 1)
    #如果是SDXL,那么isSD2需要选择为True
    fkey = convert_diffusers_name_to_compvis(key_network_without_network_parts, True) 
    
    key = fkey[16:]
    '''
    正常模型架构是model.diffusion_model, model.first_stage_model
    但是现在我们只加载了unet部分所以先去掉前半部分
    '''
    sd_module = diffusion_model.network_layer_mapping.get(key, None)
    #获取修改名字后对应的module
	
	if key not in matched_networks and sd_module is not None:  
        matched_networks[key] = NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

    if sd_module is not None:
        matched_networks[key].w[network_part] = weight
    

通过上述代码就可以将匹配的Unet层权重和lora权重放在一个NetworkWeights组了,用于后续融合。代码中w存的是lora的权重,sd module存的是对应unet中的结构和权重。

3、基于lora权重创建lora模块

上面构建了matched_networks字典,每个key下对应了一组匹配好的lora权重和模型模块,接下来就是基于lora权重创建lora模块了,我们第一步创建了一个net空壳类,在这里赋予net.modules属性,并将创建好的lora模块赋予该属性。

for key, weights in matched_networks.items():
    print(key)
    print(weights)
    net_module = None
    for nettype in module_types:
        net_module = nettype.create_module(net, weights)
        if net_module is not None:
            break
    net.modules[key] = net_module

这里的create_module来自于源码中的Lora/network_lora.py, 该函数主要是创建结构,并赋予权重,最后构建一个完整的lora模块。

4、权重融合

接下来就可以完成权重融合了,这里我以某一层为例

network_layer_name = 'output_blocks_11_1_transformer_blocks_0_attn1_to_k'
module = net.modules.get(network_layer_name, None)
print('get模块', module)
fb = matched_networks[network_layer_name].sd_module  
fb_weight = fb.weight
with torch.no_grad():
    updown, ex_bias = module.calc_updown(fb_weight) 
    if len(fb_weight) == 4 and fb_weight.shape[1] == 9:
        # inpainting model. zero pad updown to make channel[1]  4 to 9
        updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))
    fb_weight += updown

updown计算好的lora权重,fb_weight是原模块的权重,相加即可完成融合。
若要实现所有的权重融合,循环matched_networks中每一个key,然后执行上述操作,最后进行权重替换即可。

五、整体pipeline

最后来整体梳理一下:
1)加载sd模型结构,加载权重
2)加载lora模型
3)进行name匹配,找到互相对应的层
4)将互相对应的sd_module和lora权重放在一组
5)基于lora权重创建lora模块
6)完成lora的计算并融合

总结

整体lora融合的流程其实并不复杂,主要在于匹配和计算融合,SD webUI可能由于集成度较高所以看起来代码稍显复杂,感兴趣其实可以照着这个流程进一步简化。

  • 30
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是基于Qt实现LoRa通信代码示例,仅供参考: ```cpp #include <QCoreApplication> #include <QtSerialPort/QSerialPort> #include <QtSerialPort/QSerialPortInfo> #include <QDebug> #define LoRa_MODE_SLEEP 0x80 #define LoRa_MODE_STDBY 0x81 #define LoRa_MODE_TX 0x83 #define LoRa_MODE_RX_CONT 0x85 #define LoRa_MODE_RX_SINGLE 0x86 QSerialPort *serialPort; quint8 readData[256]; quint8 writeData[256]; void initSerialPort() { QList<QSerialPortInfo> portInfos = QSerialPortInfo::availablePorts(); qDebug() << "Found " << portInfos.size() << " serial ports."; foreach (QSerialPortInfo portInfo, portInfos) { qDebug() << "PortName:" << portInfo.portName(); qDebug() << "Description:" << portInfo.description(); qDebug() << "Manufacturer:" << portInfo.manufacturer(); qDebug() << "SerialNumber:" << portInfo.serialNumber(); } serialPort = new QSerialPort("COM3"); //修改为实际串口号 if(serialPort->open(QIODevice::ReadWrite)) { qDebug() << "Serial port open success!"; serialPort->setBaudRate(QSerialPort::Baud115200); serialPort->setDataBits(QSerialPort::Data8); serialPort->setStopBits(QSerialPort::OneStop); serialPort->setParity(QSerialPort::NoParity); serialPort->setFlowControl(QSerialPort::NoFlowControl); } else { qDebug() << "Serial port open failed!"; } } int main(int argc, char *argv[]) { QCoreApplication a(argc, argv); initSerialPort(); //进入睡眠模式 writeData[0] = LoRa_MODE_SLEEP; serialPort->write((const char*)writeData, 1); //设置频率、扩频因子、带宽等参数 //... //进入接收模式 writeData[0] = LoRa_MODE_RX_CONT; serialPort->write((const char*)writeData, 1); while (1) { qint64 readSize = serialPort->read((char*)readData, sizeof(readData)); if(readSize > 0) { qDebug() << "Read:" << readSize << "bytes"; for(int i=0; i<readSize; i++) { qDebug() << QString::number(readData[i], 16).toUpper().rightJustified(2, '0'); } } } return a.exec(); } ``` 其,需要根据实际情况修改串口号、频率、扩频因子、带宽等参数。同时注意LoRa通信的数据格式和解析方式。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值