Paper Reading and Source Code Analysis: Conv-LoRA: Convolution Meets LoRA: Parameter Efficient Finetuning for Segment Anything Model (ICLR 2024)
Paper: https://arxiv.org/abs/2401.17868
GitHub project: https://github.com/autogluon/autogluon/tree/master/examples/automm/Conv-LoRA
Source code: https://github.com/autogluon/autogluon/issues/4165
Motivation
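The pretrained SAM is fine-tuned on downstream tasks to improve performance.
Existing work fails to analyze or address certain limitations inherent to SAM:
- SAM's image encoder is a plain ViT, which lacks the vision-specific inductive biases that are useful for dense prediction.
- SAM's pretraining is essentially a binary mask prediction task. This low-level mask prediction pretraining hinders SAM's ability to capture the high-level image semantics that tasks such as multi-class semantic segmentation depend on.
Goal: address the limitations above while retaining the valuable segmentation knowledge SAM acquired during pretraining. Most of SAM's pretrained weights are frozen and only a small set of (additional) model parameters is fine-tuned. The fine-tuned module strengthens the image-related local priors of the SAM encoder and helps it capture high-level semantic information.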
Method
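Following the LoRA design, a bypass branch is added alongside the frozen pretrained weights:
- Project the input down to a low-dimensional (rank r) space.
- Upsample the projected features by an expert-specific ratio, pass them through that expert (each expert is essentially a 3x3 convolution), then downsample back to the original resolution to obtain the expert's output.
- A gating network selects the top-k experts, and their outputs are summed with the gate weights to form the output (the implementation below sets k = 1).
- Project the output back up to the original dimension.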
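The steps above can be written compactly. The following is a sketch of what the `forward` method in the listing further down computes when k = 1 and `multiply_by_gates` is False, ignoring the numerical details of `SparseDispatcher.combine`. The symbols are notation introduced here for illustration, not taken from the paper: $A$ and $B$ stand for `lora_A` and `lora_B`, $W_0$ and $b$ for the frozen weight and bias, $G$ for the gate logits (noisy during training), $s_i$ for the upsampling ratio of expert $i$, and $H \times W$ for the token grid size.

```latex
\begin{aligned}
z &= A\,\mathrm{Dropout}(x), \qquad A \in \mathbb{R}^{r \times C_{in}} \\
i^{*} &= \arg\max_{i}\; G(z)_i \\
E_i(z) &= \mathrm{Resize}_{H \times W}\Big(\mathrm{GELU}\big(\mathrm{Conv}^{(i)}_{3 \times 3}(\mathrm{Up}_{s_i}(z))\big)\Big), \qquad s_i = i \\
h &= W_0 x + b + \frac{\alpha}{r}\, B\,\big(z + E_{i^{*}}(z)\big), \qquad B \in \mathbb{R}^{C_{out} \times r}
\end{aligned}
```

Note the residual term $z$ inside the bypass: if the selected expert contributed nothing, the branch would reduce to the ordinary LoRA update $\frac{\alpha}{r} B A x$.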
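What the Conv-LoRA module does:
- It introduces small trainable linear projection layers into every transformer layer of the SAM encoder, which helps the encoder recover its ability to extract high-level semantic information.
- It integrates lightweight convolution layers into the LoRA bottleneck. Through local spatial operations, convolution injects image-related local priors (a pixel is more strongly correlated with its neighbors than with distant pixels).
- It combines multiple parallel convolution experts, each specialized for a different feature scale, so that the local prior is injected at the appropriate scale of the image features.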
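To give a feel for how small the added module is, here is a rough parameter count for one wrapped linear layer. The formulas follow the tensor shapes in the listing below (`lora_A`, `lora_B`, one 3x3 `Conv2d` with bias per expert, and the gate's `w_gate` and `w_noise`). The function name and the concrete sizes (768 -> 2304, rank 4, 8 experts) are assumptions chosen for illustration, not values taken from the paper.

```python
def conv_lora_param_count(in_features, out_features, r, num_experts):
    """Count the trainable parameters one Conv-LoRA bypass adds to a linear layer."""
    counts = {
        "lora_A": r * in_features,                     # down-projection, shape (r, in_features)
        "lora_B": out_features * r,                    # up-projection, shape (out_features, r)
        "experts": num_experts * (r * r * 3 * 3 + r),  # 3x3 conv weight plus bias, per expert
        "gate": 2 * r * num_experts,                   # w_gate and w_noise, each (r, num_experts)
    }
    counts["total"] = sum(counts.values())
    return counts


# Hypothetical sizes: a 768 -> 2304 projection, rank 4, 8 experts.
counts = conv_lora_param_count(in_features=768, out_features=2304, r=4, num_experts=8)
frozen = 768 * 2304 + 2304  # frozen weight and bias of the wrapped layer
print(counts)
print(f"trainable / frozen = {counts['total'] / frozen:.4%}")
```

With these assumed sizes the bypass adds fewer than 1% of the parameters of the frozen layer it wraps, and the convolution experts account for only a small share of that.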
Source Code Walkthrough
# Modified based on https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
import math
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
def identity(x):
    return x
class LoRALayer:
    """
    Abstract LoRA Layer.

    Parameters
    ----------
    r
        rank r of the low-rank decomposition.
    lora_alpha
        Scaling factor. Can be simply set to same value as r as initialization is scaled already.
    merge_weights
        Merging weights during inference to reduce latency.

    References
    ----------
    1. Edward J. Hu*, Yelong Shen*, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen,
    "LoRA: Low-Rank Adaptation of Large Language Models", 2021
    https://arxiv.org/abs/2106.09685
    2. Code: https://github.com/microsoft/LoRA
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        # Optional dropout
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = identity
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights
class ConvLoRALinear(nn.Linear, LoRALayer):
    """
    Conv-LoRA incorporated in Linear Layer. Weights of linear layer are set to be frozen per default.

    Parameters
    ----------
    in_features
        input dimension, set to the original linear layer input dimension LoRA is replacing.
    out_features
        output dimension, set to the original linear layer output dimension LoRA is replacing.
    r
        rank r of the low-rank decomposition.
    lora_alpha
        Scaling factor. Can be simply set to same value as r as initialization is scaled already.
    lora_dropout
        Dropout probability.
    fan_in_fan_out
        Set this to True if the layer to replace stores weight like (fan_in, fan_out).
    merge_weights
        Merging weights during inference to reduce latency.
    conv_lora_expert_num
        The number of experts in MoE-Conv.

    References
    ----------
    1. Zihan Zhong, Zhiqiang Tang, Tong He, Haoyang Fang, Chun Yuan,
    "Convolution Meets LoRA: Parameter Efficient Finetuning for Segment Anything Model", 2024
    https://arxiv.org/abs/2401.17868
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = False,
        conv_lora_expert_num: Optional[int] = None,
        **kwargs,
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # MoE-Conv
            topk = 1
            self.lora_moe_gating = MoEGate(M=conv_lora_expert_num, d=self.r, K=topk)
            self.lora_moe_experts = nn.ModuleList([])
            # Expert i works at upsampling ratio i (1, 2, ..., conv_lora_expert_num)
            self.upsample_ratios = list(range(1, conv_lora_expert_num + 1))
            for upsample_ratio in self.upsample_ratios:
                # Each conv expert is a 3x3 convolution followed by a nonlinearity (GELU)
                expert = nn.Conv2d(in_channels=r, out_channels=r, kernel_size=3, stride=1, padding=1, bias=True)
                expert.bias.data.zero_()
                self.lora_moe_experts.append(nn.Sequential(expert, nn.GELU()))
            self.num_experts = conv_lora_expert_num
            self.multiply_by_gates = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        nn.Linear.reset_parameters(self)
        if hasattr(self, "lora_A"):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def T(self, w):
        return w.T if self.fan_in_fan_out else w

    def forward(self, x: torch.Tensor):
        # Apply the frozen pretrained linear mapping to the input
        result = F.linear(x, self.T(self.weight), bias=self.bias)
        if self.r > 0:
            # Down-projection: reduce the channel dimension of the input to r
            lora_res = self.lora_dropout(x) @ self.lora_A.T
            # Reshape to a 4-D grid: (B, H, W, C) here, permuted to (B, C, H, W) below
            dim = lora_res.dim()
            if dim == 3:
                B, L, C = lora_res.size()
                H = W = int(math.sqrt(L))  # assumes a square token grid
                lora_res = lora_res.reshape(B, H, W, C)
            else:
                H, W = lora_res.size()[1:3]
            # Calculate the gating values.
            lora_res = lora_res.permute(0, 3, 1, 2).contiguous()
            # The gate also returns an auxiliary load-balancing loss that discourages
            # it from always selecting the same few experts during training
            gates, moe_loss = self.lora_moe_gating(lora_res)
            # Distribute data samples to experts.
            dispatcher = SparseDispatcher(self.num_experts, gates)
            expert_inputs = dispatcher.dispatch(lora_res)
            expert_outputs = []
            # Run each expert on the samples routed to it
            for i in range(self.num_experts):
                if len(expert_inputs[i]) == 0:
                    continue
                upsample_ratio = self.upsample_ratios[i]
                cur_res = expert_inputs[i]
                # Upsample by interpolation
                if upsample_ratio != 1:
                    cur_res = F.interpolate(cur_res, scale_factor=upsample_ratio, mode="bicubic")
                # Feed the upsampled tensor through the conv expert
                cur_res = self.lora_moe_experts[i](cur_res)
                # Downsample back to the original spatial size
                if upsample_ratio != 1:
                    cur_res = F.interpolate(cur_res, size=(int(H), int(W)), mode="bicubic")
                expert_outputs.append(cur_res)
            # Combine data samples after processing by each expert.
            temp_lora_res = dispatcher.combine(expert_outputs, multiply_by_gates=self.multiply_by_gates)
            # Residual connection around the experts
            lora_res = lora_res + temp_lora_res
            lora_res = lora_res.permute(0, 2, 3, 1).contiguous()
            if dim == 3:
                lora_res = lora_res.reshape(B, L, C)
            # Up-projection back to the original channel dimension, added to the frozen output
            result += (lora_res @ self.lora_B.T) * self.scaling
        # NOTE: moe_loss is only assigned in the r > 0 branch above
        return result, moe_loss
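One consequence of `reset_parameters` is worth spelling out: because `lora_B` starts at zero, the bypass contributes nothing at initialization and the wrapped layer reproduces the frozen layer exactly. The sketch below checks this with plain Python lists. It models only the two linear projections and leaves out dropout and the convolution experts, which is harmless for this check because whatever the experts produce is still multiplied by a zero `lora_B`. The helper names (`matmul`, `transpose`, `lora_forward`) and the tiny sizes are made up for illustration.

```python
import random


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def lora_forward(x, weight, lora_a, lora_b, alpha, r):
    """Frozen linear output plus the scaled low-rank update (no bias, no experts)."""
    base = matmul(x, transpose(weight))      # x @ W^T, shape (n, out_features)
    low = matmul(x, transpose(lora_a))       # x @ A^T, shape (n, r)
    update = matmul(low, transpose(lora_b))  # (x @ A^T) @ B^T, shape (n, out_features)
    scale = alpha / r
    return [[b + scale * u for b, u in zip(b_row, u_row)] for b_row, u_row in zip(base, update)]


random.seed(0)
in_features, out_features, r = 4, 3, 2
x = [[random.uniform(-1, 1) for _ in range(in_features)] for _ in range(5)]
weight = [[random.uniform(-1, 1) for _ in range(in_features)] for _ in range(out_features)]
lora_a = [[random.uniform(-1, 1) for _ in range(in_features)] for _ in range(r)]  # random init
lora_b = [[0.0] * r for _ in range(out_features)]                                 # zero init

out = lora_forward(x, weight, lora_a, lora_b, alpha=2, r=r)
base = matmul(x, transpose(weight))
print(out == base)
```

Starting from an exact copy of the pretrained behaviour is what lets fine-tuning move away from SAM's weights gradually instead of from a randomly perturbed model.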
class MoEGate(nn.Module):
    def __init__(self, d, M=4, K=1, noisy_gating=True):
        """Constructor
        Args:
            d: input channel dimensionality.
            M: the number of experts.
            K: the number of chosen experts for each forward pass.
        """
        super(MoEGate, self).__init__()
        self.M = M
        self.k = K
        self.gap = nn.AdaptiveAvgPool2d((1, 1))  # global average pooling
        self.noisy_gating = noisy_gating
        self.w_gate = nn.Parameter(torch.zeros(d, M), requires_grad=True)
        self.w_noise = nn.Parameter(torch.zeros(d, M), requires_grad=True)
        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)
        self.register_buffer("mean", torch.tensor([0.0]))
        self.register_buffer("std", torch.tensor([1.0]))
        assert self.k <= self.M

    def forward(self, feats, loss_coef=1e-2, noise_epsilon=1e-2):
        batch_size = feats.shape[0]
        feats_S = self.gap(feats).view(batch_size, -1)
        clean_logits = feats_S @ self.w_gate
        if self.noisy_gating and self.training:
            raw_noise_stddev = feats_S @ self.w_noise
            noise_stddev = self.softplus(raw_noise_stddev) + noise_epsilon
            noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)
            logits = noisy_logits
        else:
            logits = clean_logits
        top_logits, top_indices = logits.topk(min(self.k + 1, self.M), dim=1)
        top_k_logits = top_logits[:, : self.k]
        top_k_indices = top_indices[:, : self.k]
        top_k_gates = self.softmax(top_k_logits)
        zeros = torch.zeros_like(logits, requires_grad=True).float()
        gates = zeros.scatter(1, top_k_indices, top_k_gates).to(logits.dtype)
        if self.noisy_gating and self.k < self.M and self.training:
            load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0)
        else:
            load = self._gates_to_load(gates)
        importance = gates.sum(0)
        loss = self.cv_squared(importance) + self.cv_squared(load)
        loss *= loss_coef
        return gates, loss

    def _gates_to_load(self, gates):
        """Compute the true load per expert, given the gates.
        The load is the number of examples for which the corresponding gate is >0.
        Args:
            gates: a `Tensor` of shape [batch_size, n]
        Returns:
            a float32 `Tensor` of shape [n]
        """
        return (gates > 0).sum(0)

    def cv_squared(self, x):
        """The squared coefficient of variation of a sample.
        Useful as a loss to encourage a positive distribution to be more uniform.
        Epsilons added for numerical stability.
        Returns 0 for an empty Tensor.
        Args:
            x: a `Tensor`.
        Returns:
            a `Scalar`.
        """
        eps = 1e-10
        if x.shape[0] == 1:
            return torch.tensor([0], device=x.device, dtype=x.dtype)
        return x.float().var() / (x.float().mean() ** 2 + eps)

    def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values):
        """Helper function to NoisyTopKGating.
        Computes the probability that value is in top k, given different random noise.
        This gives us a way of backpropagating from a loss that balances the number
        of times each expert is in the top k experts per example.
        In the case of no noise, pass in None for noise_stddev, and the result will
        not be differentiable.
        Args:
            clean_values: a `Tensor` of shape [batch, n].
            noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus
                normally distributed noise with standard deviation noise_stddev.
            noise_stddev: a `Tensor` of shape [batch, n], or None
            noisy_top_values: a `Tensor` of shape [batch, m].
                "values" Output of tf.top_k(noisy_top_values, m). m >= k+1
        Returns:
            a `Tensor` of shape [batch, n].
        """
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()
        threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k
        threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1)
        is_in = torch.gt(noisy_values, threshold_if_in)
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1)
        # is each value currently in the top k.
        normal = Normal(self.mean, self.std)
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        prob = torch.where(is_in, prob_if_in, prob_if_out)
        return prob
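The gate's two core ideas, top-k selection and the balancing loss, can be tried out without PyTorch. The sketch below is a simplified plain-Python version under stated assumptions: it leaves out the learned noise term and the `_prob_in_top_k` load estimate, and the logits are invented numbers. `top_k_gates` mirrors the topk / softmax / scatter sequence in `MoEGate.forward`, and `cv_squared` mirrors the method of the same name, using the sample variance (n - 1 denominator) that `Tensor.var()` uses by default.

```python
import math
import statistics


def top_k_gates(logits, k):
    """Keep the k largest logits, softmax over them, and leave all other gates at zero."""
    chosen = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    peak = max(logits[i] for i in chosen)
    exps = {i: math.exp(logits[i] - peak) for i in chosen}  # shifted for numerical stability
    total = sum(exps.values())
    gates = [0.0] * len(logits)
    for i in chosen:
        gates[i] = exps[i] / total
    return gates


def cv_squared(values, eps=1e-10):
    """Squared coefficient of variation: sample variance divided by squared mean."""
    if len(values) == 1:
        return 0.0
    mean = statistics.fmean(values)
    return statistics.variance(values) / (mean**2 + eps)


# Made-up gate logits for a batch of 3 samples and 4 experts.
batch_logits = [
    [2.0, 0.5, 1.0, -1.0],
    [0.1, 0.3, 3.0, 0.2],
    [1.5, 0.2, 0.1, 0.0],
]
gates = [top_k_gates(row, k=1) for row in batch_logits]
importance = [sum(col) for col in zip(*gates)]  # total gate weight each expert receives

print(gates)
print(importance, cv_squared(importance))
print(cv_squared([1.0, 1.0, 1.0, 1.0]))
```

The loss is zero when every expert receives the same total weight and grows as routing concentrates on a few experts, which is why adding it to the training objective discourages the gate from collapsing onto one expert.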
class SparseDispatcher(object):
    """Helper for implementing a mixture of experts.
    The purpose of this class is to create input minibatches for the
    experts and to combine the results of the experts to form a unified
    output tensor.
    There are two functions:
    dispatch - take an input Tensor and create input Tensors for each expert.
    combine - take output Tensors from each expert and form a combined output
      Tensor. Outputs from different experts for the same batch element are
      summed together, weighted by the provided "gates".
    The class is initialized with a "gates" Tensor, which specifies which
    batch elements go to which experts, and the weights to use when combining
    the outputs. Batch element b is sent to expert e iff gates[b, e] != 0.
    The inputs and outputs are all two-dimensional [batch, depth].
    Caller is responsible for collapsing additional dimensions prior to
    calling this class and reshaping the output to the original shape.
    See common_layers.reshape_like().
    Example use:
    gates: a float32 `Tensor` with shape `[batch_size, num_experts]`
    inputs: a float32 `Tensor` with shape `[batch_size, input_size]`
    experts: a list of length `num_experts` containing sub-networks.
        dispatcher = SparseDispatcher(num_experts, gates)
        expert_inputs = dispatcher.dispatch(inputs)
        expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)]
        outputs = dispatcher.combine(expert_outputs)
    The preceding code sets the output for a particular example b to:
    output[b] = Sum_i(gates[b, i] * experts[i](inputs[b]))
    This class takes advantage of sparsity in the gate matrix by including in the
    `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`.

    References
    ----------
    1. Noam Shazeer, Azalia Mirhoseini, Krzysztof Maziarz, Andy Davis, Quoc Le, Geoffrey Hinton, Jeff Dean,
    "Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer", 2017
    https://arxiv.org/abs/1701.06538
    2. Code: https://github.com/davidmrau/mixture-of-experts/blob/master/moe.py
    """

    def __init__(self, num_experts, gates):
        """Create a SparseDispatcher."""
        self._gates = gates
        self._num_experts = num_experts
        # sort experts
        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
        # drop indices
        _, self._expert_index = sorted_experts.split(1, dim=1)
        # get according batch index for each expert
        self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0]
        # calculate num samples that each expert gets
        self._part_sizes = (gates > 0).sum(0).tolist()
        # expand gates to match with self._batch_index
        gates_exp = gates[self._batch_index.flatten()]
        self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index)

    def dispatch(self, inp):
        """Create one input Tensor for each expert.
        The `Tensor` for an expert `i` contains the slices of `inp` corresponding
        to the batch elements `b` where `gates[b, i] > 0`.
        Args:
            inp: a `Tensor` of shape `[batch_size, <extra_input_dims>]`
        Returns:
            a list of `num_experts` `Tensor`s with shapes
            `[expert_batch_size_i, <extra_input_dims>]`.
        """
        # assigns samples to experts whose gate is nonzero
        # expand according to batch index so we can just split by _part_sizes
        inp_exp = inp[self._batch_index].squeeze(1)  # [bs * num_of chosen experts, dim]
        return torch.split(inp_exp, self._part_sizes, dim=0)

    def combine(self, expert_out, multiply_by_gates=True):
        """Sum together the expert output, weighted by the gates.
        The slice corresponding to a particular batch element `b` is computed
        as the sum over all experts `i` of the expert output, weighted by the
        corresponding gate values. If `multiply_by_gates` is set to False, the
        gate values are ignored.
        Args:
            expert_out: a list of `num_experts` `Tensor`s, each with shape
                `[expert_batch_size_i, <extra_output_dims>]`.
            multiply_by_gates: a boolean
        Returns:
            a `Tensor` with shape `[batch_size, <extra_output_dims>]`.
        """
        # apply exp to expert outputs, so we are no longer in log space
        stitched = torch.cat(expert_out, 0).exp()
        if multiply_by_gates:
            # stitched = stitched.mul(self._nonzero_gates)
            stitched = stitched.mul(self._nonzero_gates.unsqueeze(-1).unsqueeze(-1))
        # zeros = torch.zeros(self._gates.size(0), expert_out[-1].size(1), requires_grad=True, device=stitched.device)
        zeros = torch.zeros(
            self._gates.size(0),
            expert_out[-1].size()[1],
            expert_out[-1].size()[2],
            expert_out[-1].size()[3],
            requires_grad=True,
            device=stitched.device,
        )
        # combine samples that have been processed by the same k experts
        combined = zeros.index_add(0, self._batch_index, stitched.float())
        # add eps to all zero values in order to avoid nans when going back to log space
        combined[combined == 0] = np.finfo(float).eps
        # back to log space
        return combined.log()

    def expert_to_gates(self):
        """Gate values corresponding to the examples in the per-expert `Tensor`s.
        Returns:
            a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32`
            and shapes `[expert_batch_size_i]`
        """
        # split nonzero gates for each expert
        return torch.split(self._nonzero_gates, self._part_sizes, dim=0)
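The dispatch / combine round trip is easier to follow on a toy example. The sketch below is a plain-Python illustration, not the class above: samples are single floats instead of feature maps, the three "experts" are made-up functions, and `combine` is a straightforward weighted sum that leaves out the exp / log round trip in `SparseDispatcher.combine`. It shows the routing contract described in the class docstring: sample b goes to expert e when `gates[b][e] > 0`, and every result is written back to the row it came from.

```python
def dispatch(inputs, gates):
    """Group inputs by expert: expert e receives inputs[b] for every b with gates[b][e] > 0."""
    num_experts = len(gates[0])
    batch_index = [[b for b, row in enumerate(gates) if row[e] > 0] for e in range(num_experts)]
    expert_inputs = [[inputs[b] for b in rows] for rows in batch_index]
    return expert_inputs, batch_index


def combine(expert_outputs, batch_index, gates, batch_size, multiply_by_gates=True):
    """Scatter expert outputs back to their batch rows, optionally weighted by the gates."""
    combined = [0.0] * batch_size
    for e, (outputs, rows) in enumerate(zip(expert_outputs, batch_index)):
        for out, b in zip(outputs, rows):
            weight = gates[b][e] if multiply_by_gates else 1.0
            combined[b] += weight * out
    return combined


# Hypothetical experts and a top-1 routing for a batch of 3 samples.
experts = [lambda v: 10 * v, lambda v: v + 1, lambda v: -v]
inputs = [1.0, 2.0, 3.0]
gates = [
    [1.0, 0.0, 0.0],  # sample 0 -> expert 0
    [0.0, 0.0, 1.0],  # sample 1 -> expert 2
    [1.0, 0.0, 0.0],  # sample 2 -> expert 0
]

expert_inputs, batch_index = dispatch(inputs, gates)
expert_outputs = [[experts[e](v) for v in vals] for e, vals in enumerate(expert_inputs)]
combined = combine(expert_outputs, batch_index, gates, batch_size=len(inputs))
print(expert_inputs)
print(combined)
```

Expert 1 receives an empty list here, which corresponds to the `continue` branch in `ConvLoRALinear.forward` for experts that no sample in the batch was routed to.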