StableDiffusion LoRA: Principles and Code Walkthrough

How LoRA Works

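LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning method. Instead of retraining a layer's full weight matrix, it freezes the pretrained weight and learns a small low-rank update that is added to it. Applying such updates to a subset of an existing SD model's weights changes, and ideally improves, the generated images. (Most LoRA models adjust the Linear layers of the attention blocks in the Transformer; some also fine-tune Conv2D kernel weights.)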

Linear LoRA

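Technically, LoRA is simple. The basic flow is shown in the figure below: the blue part is the original pretrained weight, and the orange part is the pair of weights A and B that LoRA trains.
[Figure: LoRA structure, frozen pretrained weight (blue) with trainable A and B (orange)]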

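  • Training: A and B in the figure above are the trainable weights. During training the blue part is frozen and only A and B are updated; at the end, only the parameters of A and B are saved. The basic steps are:
    [Figure: training steps]
  • Inference: merge the update into the model weights as W = W0 + BA, then run the model as usual.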
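
Both phases rest on one identity: running the frozen weight and the low-rank branch side by side gives the same output as running the merged weight W0 + BA. The following is a minimal sketch of that identity in pure Python, using nested lists instead of tensors so that it runs without PyTorch. The helper names `matmul` and `matadd` and the toy numbers are assumptions made for illustration; they are not part of the original code.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def matadd(a, b):
    """Add two matrices of the same shape element by element."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


# Frozen pretrained weight W0 with shape [out_dim=2, in_dim=3].
W0 = [[1, 0, 2],
      [0, 1, 1]]
# Trainable low-rank factors with rank r=1.
B = [[1], [2]]       # [out_dim, r]
A = [[1, 0, 1]]      # [r, in_dim]
x = [[1], [2], [3]]  # input as a column vector [in_dim, 1]

# Training-time forward pass: frozen branch plus low-rank branch.
h_train = matadd(matmul(W0, x), matmul(B, matmul(A, x)))

# Inference-time forward pass: merge once, then use a single matrix.
W = matadd(W0, matmul(B, A))
h_merged = matmul(W, x)

print(h_train, h_merged)
```

Because the merged matrix W has the same shape as W0, the merged model has the same structure and inference cost as the original one.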

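A plain Linear LoRA only touches the weight of a Linear layer, so B (weight_up) and A (weight_down) have the same number of dimensions: len(weight_up.shape) == len(weight_down.shape) == 2. Following the nn.Linear weight layout [out_features, in_features], A (lora_down) has shape [rank, in_dim] and B (lora_up) has shape [out_dim, rank], so the product BA has shape [out_dim, in_dim], the same as the original weight.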

        if isinstance(module, nn.Linear):
            # Both LoRA weights of a Linear layer are 2-D matrices.
            assert len(self._up_weight.shape) == len(self._down_weight.shape) == 2

            in_dim = module.in_features
            out_dim = module.out_features
            # A projects the input down to rank r; B projects it back up to out_dim.
            self._lora_down = nn.Linear(in_dim, self._r, bias=False)
            self._lora_up = nn.Linear(self._r, out_dim, bias=False)

        # Load the stored LoRA weights into the two layers.
        self._lora_down.weight = nn.Parameter(self._down_weight)
        self._lora_up.weight = nn.Parameter(self._up_weight)
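
To make the shape bookkeeping concrete, here is a small sketch that derives the three shapes in the nn.Linear layout and compares parameter counts. The function name `linear_lora_shapes` and the layer sizes (320 in, 320 out, rank 4) are hypothetical values picked for illustration, not taken from any particular model.

```python
from math import prod


def linear_lora_shapes(in_dim, out_dim, rank):
    """Return (down, up, merged) weight shapes in nn.Linear layout [out, in]."""
    down = (rank, in_dim)       # lora_down.weight, i.e. A
    up = (out_dim, rank)        # lora_up.weight, i.e. B
    merged = (up[0], down[1])   # shape of the product B @ A
    return down, up, merged


# Hypothetical layer sizes, chosen only for illustration.
down, up, merged = linear_lora_shapes(in_dim=320, out_dim=320, rank=4)

lora_params = prod(down) + prod(up)  # parameters that LoRA trains and saves
full_params = prod(merged)           # parameters of the full weight matrix

print(down, up, merged)
print(lora_params, full_params)
```

With these numbers the two LoRA matrices hold 2,560 parameters against 102,400 for the full weight, which is why a LoRA file is so much smaller than the base model.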

Conv2d LoRA

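A convolution layer and a fully connected layer are different operations. For a convolution layer, LoRA is applied to the weight tensor of the convolution kernel. Once that is clear, the approach is essentially the same as for a fully connected layer.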

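For Conv2d, both stored weights are 4-D, but the up weight always has a 1×1 kernel. When merging, the 1×1 dimensions are squeezed away, so B (weight_up) becomes 2-D while A (weight_down) stays 4-D unless its kernel is also 1×1; in that case len(weight_up.shape) no longer equals len(weight_down.shape). The merge keeps the last two (kernel) dimensions unchanged and performs a matrix product over the first two. Following the nn.Conv2d weight layout [out_channels, in_channels, h, w], A (lora_down) has shape [rank, in_dim, h, w], B (lora_up) has shape [out_dim, rank, 1, 1], and the product has shape [out_dim, in_dim, h, w].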

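The shapes of A and B are chosen this way for a reason: BA must end up with the same shape as the convolution kernel weight of the original model (self.weight), so they are derived from the shape of self.weight:
[Figure: A and B shapes derived from self.weight]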

        if isinstance(module, nn.Conv2d):
            # Both LoRA weights of a Conv2d layer are stored as 4-D tensors.
            assert len(self._up_weight.shape) == len(self._down_weight.shape) == 4

            r = self._r
            in_dim = module.in_channels
            out_dim = module.out_channels
            kernel = module.kernel_size
            stride = module.stride
            padding = module.padding

            # A reuses the original kernel size, stride and padding; B is a 1x1 convolution.
            self._lora_down = nn.Conv2d(in_dim, r, kernel, stride, padding, bias=False)
            self._lora_up = nn.Conv2d(r, out_dim, (1, 1), (1, 1), bias=False)

        # Load the stored LoRA weights into the two layers.
        self._lora_down.weight = nn.Parameter(self._down_weight)
        self._lora_up.weight = nn.Parameter(self._up_weight)
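
The Conv2d merge is a contraction over the rank index only, with the kernel dimensions carried along untouched. The sketch below spells that contraction out with nested lists in pure Python so that it runs without PyTorch; it follows the same index pattern as an einsum of the form 'ab,bchw->achw'. The function name `merge_conv_lora` and the toy values are assumptions for illustration.

```python
def merge_conv_lora(up, down):
    """Contract the rank index of a squeezed up weight with a 4-D down weight.

    up:   [out_dim][rank]              (1x1 kernel already squeezed away)
    down: [rank][in_dim][h][w]
    returns [out_dim][in_dim][h][w]
    """
    out_dim, rank = len(up), len(down)
    in_dim, kh, kw = len(down[0]), len(down[0][0]), len(down[0][0][0])
    return [[[[sum(up[a][b] * down[b][c][h][w] for b in range(rank))
               for w in range(kw)]
              for h in range(kh)]
             for c in range(in_dim)]
            for a in range(out_dim)]


# Toy weights: rank=2, in_dim=1, 2x2 kernel, out_dim=3.
down = [[[[1, 2], [3, 4]]],
        [[[0, 1], [1, 0]]]]
up = [[1, 0],
      [0, 1],
      [2, -1]]

merged = merge_conv_lora(up, down)
shape = (len(merged), len(merged[0]), len(merged[0][0]), len(merged[0][0][0]))
print(shape)
print(merged[2][0])
```

Each output channel of the merged kernel is a linear combination of the rank-many down kernels, weighted by the corresponding row of the up matrix.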

LoRA File Contents

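Each layer in a LoRA file has three entries: .lora_down.weight, .lora_up.weight, and .alpha. Here down and up are the two low-rank weights of the layer, corresponding to A and B respectively, and alpha is a scalar scaling factor stored with them (typically a fixed hyperparameter rather than a trained weight). Many loaders additionally divide alpha by the rank; the code in this article multiplies by alpha directly. The weight update of each layer can be written as:
$w = \text{alpha} \cdot (\text{upMatrix} \ @ \ \text{downMatrix})$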

Taking LCM LoRA, currently one of the most popular LoRAs, as an example, the keys look like this (partial list):

lora_unet_down_blocks_0_downsamplers_0_conv.alpha
lora_unet_down_blocks_0_downsamplers_0_conv.lora_down.weight
lora_unet_down_blocks_0_downsamplers_0_conv.lora_up.weight
lora_unet_down_blocks_0_resnets_0_conv1.alpha
lora_unet_down_blocks_0_resnets_0_conv1.lora_down.weight
lora_unet_down_blocks_0_resnets_0_conv1.lora_up.weight
lora_unet_down_blocks_0_resnets_0_conv2.alpha
lora_unet_down_blocks_0_resnets_0_conv2.lora_down.weight
lora_unet_down_blocks_0_resnets_0_conv2.lora_up.weight
lora_unet_down_blocks_0_resnets_0_time_emb_proj.alpha
lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_down.weight
lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_up.weight
lora_unet_down_blocks_0_resnets_1_conv1.alpha
lora_unet_down_blocks_0_resnets_1_conv1.lora_down.weight
lora_unet_down_blocks_0_resnets_1_conv1.lora_up.weight
lora_unet_down_blocks_0_resnets_1_conv2.alpha
lora_unet_down_blocks_0_resnets_1_conv2.lora_down.weight
lora_unet_down_blocks_0_resnets_1_conv2.lora_up.weight
lora_unet_down_blocks_0_resnets_1_time_emb_proj.alpha
lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_down.weight
lora_unet_down_blocks_0_resnets_1_time_emb_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_0_proj_in.alpha
lora_unet_down_blocks_1_attentions_0_proj_in.lora_down.weight
lora_unet_down_blocks_1_attentions_0_proj_in.lora_up.weight
lora_unet_down_blocks_1_attentions_0_proj_out.alpha
lora_unet_down_blocks_1_attentions_0_proj_out.lora_down.weight
lora_unet_down_blocks_1_attentions_0_proj_out.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_k.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_out_0.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_v.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_k.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_out_0.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_q.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_v.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn2_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_0_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_2.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_2.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_ff_net_2.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_k.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_out_0.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_q.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_v.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn1_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_k.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_out_0.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_q.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_v.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_attn2_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_0_proj.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_0_proj.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_0_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_2.alpha
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_2.lora_down.weight
lora_unet_down_blocks_1_attentions_0_transformer_blocks_1_ff_net_2.lora_up.weight
lora_unet_down_blocks_1_attentions_1_proj_in.alpha
lora_unet_down_blocks_1_attentions_1_proj_in.lora_down.weight
lora_unet_down_blocks_1_attentions_1_proj_in.lora_up.weight
lora_unet_down_blocks_1_attentions_1_proj_out.alpha
lora_unet_down_blocks_1_attentions_1_proj_out.lora_down.weight
lora_unet_down_blocks_1_attentions_1_proj_out.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_k.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_out_0.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_q.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_v.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn1_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_k.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_out_0.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_q.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_v.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_attn2_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_0_proj.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_0_proj.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_0_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_2.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_2.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_0_ff_net_2.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_k.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_out_0.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_q.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_v.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn1_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_k.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_k.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_k.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_out_0.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_out_0.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_out_0.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_q.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_q.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_q.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_v.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_v.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_attn2_to_v.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_0_proj.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_0_proj.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_0_proj.lora_up.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_2.alpha
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_2.lora_down.weight
lora_unet_down_blocks_1_attentions_1_transformer_blocks_1_ff_net_2.lora_up.weight
lora_unet_down_blocks_1_downsamplers_0_conv.alpha
lora_unet_down_blocks_1_downsamplers_0_conv.lora_down.weight
lora_unet_down_blocks_1_downsamplers_0_conv.lora_up.weight
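
As the list shows, every module name appears three times, once per suffix. A quick way to check a file for that pattern is to group the keys by module name. This is a minimal sketch in plain Python working on a list of key strings; the function name `group_lora_keys` is hypothetical, and in practice the keys would come from the loaded state dict.

```python
from collections import defaultdict

SUFFIXES = (".lora_down.weight", ".lora_up.weight", ".alpha")


def group_lora_keys(keys):
    """Map each module name to the set of LoRA suffixes found for it."""
    modules = defaultdict(set)
    for key in keys:
        for suffix in SUFFIXES:
            if key.endswith(suffix):
                modules[key[: -len(suffix)]].add(suffix)
                break
    return dict(modules)


# A few keys copied from the list above.
keys = [
    "lora_unet_down_blocks_0_downsamplers_0_conv.alpha",
    "lora_unet_down_blocks_0_downsamplers_0_conv.lora_down.weight",
    "lora_unet_down_blocks_0_downsamplers_0_conv.lora_up.weight",
    "lora_unet_down_blocks_1_attentions_0_proj_in.alpha",
    "lora_unet_down_blocks_1_attentions_0_proj_in.lora_down.weight",
    "lora_unet_down_blocks_1_attentions_0_proj_in.lora_up.weight",
]

groups = group_lora_keys(keys)
# True when every module has all three entries.
complete = all(parts == set(SUFFIXES) for parts in groups.values())
print(sorted(groups), complete)
```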

Loading LoRA Weights

Take loading a LoRA into the UNet as an example. First get the UNet's original weights W0 (state_dict_unet) and the LoRA weights BA (state_dict_lora). Then update each matching layer as $W = W_0 + BA$, that is, state_dict_unet[name] += state_dict_lora[name].

    def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
        # W0: the original UNet weights.
        state_dict_unet = unet.state_dict()
        # BA: merge each up/down pair into a single delta per layer.
        state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
        # Rename the diffusers-style keys to the names used by SDUNet.
        state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
        if len(state_dict_lora) > 0:
            for name in state_dict_lora:
                # W = W0 + BA
                state_dict_unet[name] += state_dict_lora[name].to(device=device)
            unet.load_state_dict(state_dict_unet)

Here convert_state_dict multiplies the up and down weights together to obtain BA:
$BA = \text{alpha} \cdot (\text{upMatrix} \ @ \ \text{downMatrix})$

Note: the LoRA here contains not only Linear LoRA weights but also Conv2d LoRA weights.

    # Requires `import torch` at module level.
    def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
        # Dotted fragments created by the "_" -> "." replacement that must be restored.
        special_keys = {
            "down.blocks": "down_blocks",
            "up.blocks": "up_blocks",
            "mid.block": "mid_block",
            "proj.in": "proj_in",
            "proj.out": "proj_out",
            "transformer.blocks": "transformer_blocks",
            "to.q": "to_q",
            "to.k": "to_k",
            "to.v": "to_v",
            "to.out": "to_out",
        }
        state_dict_ = {}
        for key in state_dict:
            # Visit each layer once, through its lora_up entry.
            if ".lora_up" not in key:
                continue
            if not key.startswith(lora_prefix):
                continue
            # Use the device argument instead of a hard-coded "cuda".
            weight_up = state_dict[key].to(device=device, dtype=torch.float16)
            weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device=device, dtype=torch.float16)
            # LCM LoRA stores a per-layer alpha; combine it with the caller's alpha
            # rather than overwriting the argument.
            lora_alpha = state_dict[key.replace(".lora_up.weight", ".alpha")].to(device=device, dtype=torch.float16)
            scale = alpha * lora_alpha
            print(key, "@", key.replace(".lora_up", ".lora_down"))
            print(lora_alpha, weight_up.shape, weight_down.shape)
            if len(weight_up.shape) == 4:
                # Conv2d LoRA: drop the 1x1 kernel dimensions where present.
                weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
                weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
                if len(weight_up.shape) == len(weight_down.shape):  # 1x1 conv: both are 2-D now
                    lora_weight = scale * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                else:  # k x k conv: weight_down is still 4-D
                    lora_weight = scale * torch.einsum('ab,bchw->achw', weight_up, weight_down)
            else:
                # Linear LoRA: plain matrix product B @ A.
                lora_weight = scale * torch.mm(weight_up, weight_down)
            # Turn the flat LoRA key into a dotted module path ending in ".weight".
            target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
            for special_key in special_keys:
                target_name = target_name.replace(special_key, special_keys[special_key])
            state_dict_[target_name] = lora_weight.cpu()
        return state_dict_
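
The key renaming at the end of convert_state_dict is easy to miss, so here it is isolated as a standalone sketch. It repeats the same string steps on plain key strings and needs no tensors. The function name `lora_key_to_target_name` is hypothetical; the replacement table is copied from the code above.

```python
SPECIAL_KEYS = {
    "down.blocks": "down_blocks",
    "up.blocks": "up_blocks",
    "mid.block": "mid_block",
    "proj.in": "proj_in",
    "proj.out": "proj_out",
    "transformer.blocks": "transformer_blocks",
    "to.q": "to_q",
    "to.k": "to_k",
    "to.v": "to_v",
    "to.out": "to_out",
}


def lora_key_to_target_name(key, lora_prefix="lora_unet_"):
    """Convert a flat LoRA key into a dotted module weight name."""
    # Keep the module part, turn every "_" into ".", then drop the prefix.
    name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
    # Restore the names that legitimately contain underscores.
    for dotted, restored in SPECIAL_KEYS.items():
        name = name.replace(dotted, restored)
    return name


proj_in = lora_key_to_target_name(
    "lora_unet_down_blocks_1_attentions_0_proj_in.lora_up.weight")
to_out = lora_key_to_target_name(
    "lora_unet_down_blocks_1_attentions_0_transformer_blocks_0_attn1_to_out_0.lora_up.weight")
time_emb = lora_key_to_target_name(
    "lora_unet_down_blocks_0_resnets_0_time_emb_proj.lora_up.weight")

print(proj_in)
print(to_out)
print(time_emb)
```

With the table as shown, a name such as time_emb_proj is not restored and comes out as time.emb.proj, so layers like that may need extra entries depending on what the downstream converter expects.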