flux.1模型架构探究

Flux.1 pipeline

在这里插入图片描述

纵观Flux.1的整个工作流程和官方提供的源码,我们可以总结出以下几点:

  • 统一提示词处理:Flux没有区分正向和反向提示词。提示词(prompt)一方面通过T5编码器生成主要的引导张量,引导图像生成;另一方面,通过CLIP编码器提取全局语义信息,作为向量(vec)用于生成每一层中的scale(缩放)、shift(偏置)和gate(门控)参数,调节图像和文本的中间变量。
  • 先进的位置编码:Flux采用了更为强大的RoPE(Rotary Position Embedding)位置编码,通过干预每一层中的自注意力机制(self-attention)的查询(q)和键(k)向量,影响图像生成过程,提高生成质量。
  • 双流与单流层结合:与Stable Diffusion 3类似,Flux主要通过DoubleStreamBlock中的自注意力机制使图像的潜在变量与文本信息融合。不同的是,Flux新增了SingleStreamBlock,将图像和文本的潜在变量简单地拼接(concat)在一起,并进行统一处理,进一步增强信息融合效果

随机噪声(img)和文本编码(txt)

以生成512x512的图片为例,同时T5_tokenizer的参数设置为(max_length=512,padding=True),首先会生成一个形状为[bs, 16, 64, 64]的高斯噪声张量,作为初始的latent表示

return torch.randn(
        num_samples,
        16,
        2 * math.ceil(height / 16),
        2 * math.ceil(width / 16),
        device=device,
        dtype=dtype,        generator=torch.Generator(device=device).manual_seed(seed),
    )

接着,由于Flux内部采用了图块化操作,所以高斯噪声的维度被变换为[bs, 1024, 64],同时保证了第三个维度始终为64。变换之后,相邻的四个像素点组成一个图块,并将16个通道对应的图块上的所有像素合并到一个维度上

img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

这种变换方式确保了在后续的处理过程中,图像的空间信息能够被有效地保留和利用。经过Flux处理后,维度会重新变回[bs, 16, 64, 64],以便进行VAE解码

Flux直接用了Google的T5模型的encoder部分,T5模型是一种强大的预训练语言模型,能够高效地处理各种自然语言处理任务。编码器的输出txt的维度为[bs, 512, 4096],包含了文本的局部和整体语义信息

self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)

txt = t5(prompt)

在进入模型主体之前,img和txt都需要通过Linear层,将维度分别调整为:

  • img: [bs, 1024, 64] ==> [bs, 1024, 3072]
  • txt: [bs, 512, 4096] ==> [bs, 512, 3072]

这样做的目的是为了使得图像和文本的特征维度一致,便于后续的融合和处理。

pe嵌入(位置编码)

位置编码是Transformer模型中的一个重要组成部分,用于在输入中引入位置信息,使得模型能够感知序列中各个元素的位置。Flux对图像和文本分别进行位置编码

img_ids = torch.zeros(h // 2, w // 2, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

首先,我们初始化一个形状为(h/2, w/2, 3)的全零张量img_ids,并为每个位置(i, j)的图块编号为(0, i, j)。然后将img_ids从[h, w, c]变换为[1, (h x w), c]

txt_ids = torch.zeros(bs, txt.shape[1], 3)

对于文本,Flux初始化一个形状为[1, 512, 3]的全零张量txt_ids

ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)

奇怪的是,Flux后续并没有对txt_ids做任何更改,直接将img_ids和txt_ids拼接到一起,进入pe_embedder中进行位置编码,这里我的猜想是,prompt已经通过T5模型的编码,所以Flux已经可以从prompt的潜在变量txt中get到token之间的关联信息,所以这一步要做的只有两件事

  • 对图块进行位置编码
  • 区分文字信息和图块信息

而将txt_ids全部置零可以满足这两个要求

之后pe_embedder会分别按照三个维度进行rope编码并将其拼接到一起(n_axes=3),形成一个[bs, 1, 1536, 64, 2, 2]的张量

def forward(self, ids: Tensor) -> Tensor:
    n_axes = ids.shape[-1]
    emb = torch.cat(
        [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
        dim=-3,
    )

    return emb.unsqueeze(1)

vec嵌入

时间步(timestep)和引导(guidance)都是常数,经过timestep_embedding处理后变成了[bs, 256]的张量。同时Flux使用了CLIP模型来获取文本的全局特征向量(y)

y = clip(prompt)
 
vec = self.time_in(timestep_embedding(timestep, 256))
#vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) (schenell中没有guidance_in)
vec = vec + self.vector_in(y)

vec会综合timestepguidanceCLIP的输出y,这里的time_in,vector_in,guidance_in都是Linear层,最终vec的维度为[bs, 3072]

CLIP 的输出 pooler_output(y) 是对输入文本序列进行池化操作后得到的全局特征向量。这个向量表示了整个文本的语义信息,可以用于各种下游任务,如文本分类、相似度计算等。

双流层和单流层

Modulation

在这里插入图片描述

Modulation的输入是vec,输出为两个mod,每个mod中含有(scale, shift, gate),通过这三个参数来对生成过程中的潜在变量进行干预和调制

mlp

在这里插入图片描述

Attention

所有的模块,如Linear,Modulation,mlp,ScaleAndShift都只对张量的最后一维进行操作,即只对img中的单个单元和txt中的单个单元分别做处理,而将各个img单元和txt单元直接建立起联系的只有attention模块

def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x
    
    
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

可以看到,经过rope编码后得到的pe张量,通过影响自注意力机制中的q,k来干预图像的生成

DoubleStreamBlock

双流层块用于同时处理图像和文本的特征,通过交替使用两个流(一个用于图像,一个用于文本)来实现信息的融合和交换,但是事实上,经过self-attention后,图像中已经蕴含了文本信息,文本中也蕴含了图像信息

在这里插入图片描述

SingleStreamBlock

单流层块用于进一步混合图像和文本的信息,并且包含残差连接,确保信息在网络中的流动更加顺畅

在这里插入图片描述

imtxt的维度为[bs, 1024+512, 3072],单流层除了将img和txt的信息混合之外,还出现了残差连接

最后输入进LastLayer之前,又将imtxt[bs]的前1024维取出来,认为其包含着image的完整latent信息

LastLayer

在这里插入图片描述

在最后一层中,img经过层归一化和缩放平移后,通过一个线性层将img的维度变为[1, 1024, 64]。这种处理方式确保了图像特征在生成过程中保持一致性和稳定性

生成方法

        pred = model(
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_vec,
            guidance=guidance_vec,
        )

        img = img + (t_prev - t_curr) * pred

img = img + (t_prev - t_curr) * pred可以看出,与传统预测噪声不同,flux是基于流的方法(Flow-based Models),huggingface对flux的集成中用的调度器也是FlowMatchEulerDiscreteScheduler

推理加速

通过深入分析Flux的处理流程,我们发现源码中存在一些重复计算,这些计算实际上是多余的。例如,位置编码(position encoding, pe)、y的嵌入和文本(txt)的嵌入等在每一步预测中都是不变的。然而,源码在每一步预测中都重新计算这些值。这种实现方式可能是为了便于理解代码逻辑。

测试结果显示,self-attention模块仅占总执行时间的22%。因此,虽然针对attention机制的加速方法对本模型有一定效用,但其效用相对有限。

test: 1
the total time taken for the flux: 2.447074680007063
the total time taken for the self_attention: 0.5275727066909894
ratio: 0.21559321871184853

test: 2
the total time taken for the flux: 2.351353910053149
the total time taken for the self_attention: 0.5199191314168274
ratio: 0.2211147922879357

test: 3
the total time taken for the flux: 2.355428950046189
the total time taken for the self_attention: 0.5204954411601648
ratio: 0.22097692275963496

test: 4
the total time taken for the flux: 2.355868450948037
the total time taken for the self_attention: 0.5216755487490445
ratio: 0.22143662076679807

test: 5
the total time taken for the flux: 2.354994199005887
the total time taken for the self_attention: 0.5202859512064606
ratio: 0.22092876127936484

在每一层中,图像(img)和文本(txt)的计算依赖于Modulation(vec)的计算结果,而Modulation(vec)依赖于CLIP(prompt),并不依赖于img和txt。因此,可以在计算img和txt的attention时,利用另一张显卡并行计算后续层中的Modulation(vec),并将结果提前传输到主显卡上。这种方法可以在需要时直接获取Modulation(vec)的计算结果。

进一步优化可以将Flux中的所有Modulation层独立出来放到第二张显卡上,这将有效减少对第二张显卡的内存占用。

另一种加速策略是将双流层分配到两张显卡上分别执行,仅在进行self-attention时进行通信。如果将文本流的计算放置在另一张显卡上,那么其通信量为512 x 9216 x 4字节=18MB。

在本人的主机中,显卡之间仅支持PCIe通信,不同的PCIe通信速度从500MB/s到8GB/s不等。有的显卡支持NVLink,通信速度会有质的飞跃,由于文本流的维度比图像流要小,如果两流之间计算时间的差值大于通信延迟,那么这种方法将是一个理想的加速方案。

由于sd3模型仅包含双流层,因此在sd3模型上加速效果会更加显著。

timestep_embedding和Rope

时间步嵌入(timestep_embedding)是传统方法,而Rope是Flux新引入的方法

timestep_embedding

Create sinusoidal timestep embeddings

def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        t.device
    )

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding

这段代码描述的大概是下边这两个公式

f r e q s = e − l o g ( 1000 ) ⋅ [ 1 , 2 , 3.... , h a l f ] freqs=e^{-log(1000)·[1,2,3....,half]} freqs=elog(1000)[1,2,3....,half]

e m b e d d i n g = t o r c h . c a t ( c o s ( t ⋅ f r e q s ) , s i n ( t ⋅ f r e q s ) ) embedding=torch.cat(cos(t·freqs),sin(t·freqs)) embedding=torch.cat(cos(tfreqs),sin(tfreqs))

Rope

Rope(Rotary Position Embedding)是一种新引入的位置编码方法,能够将相对位置信息集成到自注意力机制中,并提升Transformer架构的性能。其实现如下

def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(dtype=torch.float32, device=pos.device)

pos是ids的某一层张量,比如ids[… , 0] , ids[… , 1],dim和theta都是超参数

这种位置编码方式源自于论文 RoFormer: Enhanced Transformer with Rotary Position Embedding ,目前很火的 LLaMA、GLM 模型也是采用该位置编码方式

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值