fairseq框架下代码分析---模型构建(guided_transformer)

文章所描述的模型结构如图所示,guidance signal 可以理解为prompt,encoder的两个部分结构是一样的,都可以使用标准的transformerencoder,decoder端有所不同。decoder layer是guieded transformer decoder,可以看出来,每一个decoder layer中包含四个部分,1)使用output embedding或者previous embedding的self-attention,2)使用guidance encoder output作为k,v的cross -attention,3)使用source encoder output作为k,v的cross-attention.4) Feed forward 层.1. 首先定义的类是整个模型的类GuidedTransformerModel(FairseqEncoderDecoderMoel)通过@register的方式注册模型,模型继承FairseqEncoderDecoderModel

@register_model("guided_transformer")
class GuidedTransformerModel(FairseqEncoderDecoderModel):

1)定义init函数

    def __init__(self, args, encoder, decoder):
        super().__init__(encoder, decoder)
        self.args = args
        self.supports_align_args = True

2)定义add_args(parser):

    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--activation-fn',
                            choices=utils.get_available_activation_fns(),
                            help='activation function to use')

3)定义build_model(cls,args,task)
类方法中可以通过使用cls来实例化一个对象。
cls举例:

class Person(object):
    def __init__(self, name, age):
        self.name = name
        self.age = age
        print('self:', self)

    # 定义一个build方法,返回一个person实例对象,这个方法等价于Person()。
    @classmethod
    def build(cls):
        # cls()等于Person()
        p = cls("Tom", 18)
        print('cls:', cls)
        return p


if __name__ == '__main__':
    person = Person.build()
    print(person, person.name, person.age)

task是在train.py 中调用fairseq的task=tasks.setup_task(args)得到的,包含字典等
首先调用base_architecture(args)函数,将args读入属性中

@register_model_architecture("guided_transformer", "guided_transformer")
def base_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)

定义build_embedding(dictionary,embed_dim,path=None)
并分类讨论是否share_all_embeddings得到embed_tokens

最后再build_model中用这样的语句返回实例化的对象:

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return cls(args, encoder, decoder)
        

在建立模型时可以直接调用build_model

model = task.build_model(args)

4)定义build_encoder和build_decoder

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return GuidedTransformerEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return GuidedTransformerDecoder(
            args,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=getattr(args, "no_cross_attention", False),
        )

其中,而用@staticmethod或@classmethod,就可以不需要实例化,直接类名.方法名()来调用。

5)定义forward函数

2. 定义GuidedTransformerEncoder类
GuidedTransformerEncoderr由多个encoderlayer组成

class GuidedTransformerEncoder(FairseqEncoder):

1)在 def init(self, args, dictionary, embed_tokens):中定义dropout,padding_idx,max_source_positions等参数,还要定义embed_positions,layers等参数
其中,用到的PositionalEmbedding和LayerNorm都来自fairseq.modules
2)定义forward_embedding函数
根据是否embed_positions以及是否layernorm_embedding得到最终的表达
3)定义forward函数
得到src_tokens的embedding
得到encoder_padding_mask
循环几个layers,运行各个encoderlayer

最后返回

        return EncoderOut(
            encoder_out=x,  # T x B x C
            encoder_padding_mask=encoder_padding_mask,  # B x T
            encoder_embedding=encoder_embedding,  # B x T x C
            encoder_states=encoder_states,  # List[T x B x C]
        )

EncoderOut是在fairseqEncoder中定义好的tuple

EncoderOut = NamedTuple(
    "EncoderOut",
    [
        ("encoder_out", Tensor),  # T x B x C
        ("encoder_padding_mask", Tensor),  # B x T
        ("encoder_embedding", Tensor),  # B x T x C
        ("encoder_states", Optional[List[Tensor]]),  # List[T x B x C]
    ],
)

3.定义GuidedTransformerDecoder(FairSeIncrementalDecoder)
1)在def __init__中定义input_embed_dim,decoder_dim,output_dim,padding_idx,embed_tokens(Embedding方法),embed_positions(Position Embedding方法),dim映射方法,layers,等方法
2)在forward中

def forward(
        self,
        prev_output_tokens,#前一步decoder输出
        encoder_out: Optional[EncoderOut] = None,#传入None类型或者传入EncoderOut类型
        z_encoder_out: Optional[EncoderOut] = None,

调用embedding方法,将decoder端的输入进行嵌入
对encoder_out得到的输入进行预处理
依次通过各层decoderlayers
返回结果如下:

        return x, {"attn": [attn], "inner_states": inner_states}
  1. 定义TransformerEncoderLayer
    TransformerEncoderLayer在fairseq modules中的transformer_layer文件中
    1)定义init函数
    init函数中定义embed_dim
    定义self_attn如下:
        self.self_attn = MultiheadAttention(
            self.embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )

MultiheadAttention来自fairseq.modules模块
定义self_attn_layer_norm层如下:

self.self_attn_layer_norm = LayerNorm(self.embed_dim)

其中LayerNorm层也来自fairseq_modules模块

2)定义forward函数

首先使用self_attn层计算attention值`

        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            attn_mask=attn_mask,
        )

进行layer_norm
进行dropout
使用残差
连接f1全连接层,调用激活函数

x = self.activation_fn(self.fc1(x))

进行dropout
连接f2全连接层
进行dropout
使用残差

  1. 定义GuidedTransformerDecoderLayer
    1)def init
    定义embed_dim
    定义在guidedtransformer中特有的cross self attention
    定义self_attn层如下:
        self.self_attn = MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not self.cross_self_attention,
        )

定义dropout,定义激活函数,定义layernorm层,
定义综合encoder端source document的attention层:

            self.encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

定义综合encoder端guidance signal的attention层:

            self.z_encoder_attn = MultiheadAttention(
                self.embed_dim,
                args.decoder_attention_heads,
                kdim=getattr(args, "encoder_embed_dim", None),
                vdim=getattr(args, "encoder_embed_dim", None),
                dropout=args.attention_dropout,
                encoder_decoder_attention=True,
            )
            self.z_encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

kdim和vdim都是encoder_embed_dim, 因为q,k,v中的k和v是来自encoder,q来自decoder

定义fc1,fc2等

2)定义forward函数
依次通过定义的self-attention层,2个cross-attention层,以及最后的forward层。

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
回答: 当使用guided-diffusion预训练模型进行采样时,可能会遇到报错的情况。根据提供的引用内容,我没有找到直接与guided-diffusion预训练模型采样报错相关的信息。然而,根据引用\[1\]中提到的模型更换方法和引用\[2\]中提到的训练hypernets的方式,您可以尝试以下几个步骤来解决报错问题: 1. 确保您已正确设置模型的路径和文件位置。根据引用\[1\]中的说明,您可以将需要的模型移入指定的文件夹,或使用ChangeModel函数更换模型的路径。 2. 检查模型的训练方式和参数设置是否正确。根据引用\[2\]中提到的训练hypernets的方式,确保您在训练模型时使用了正确的学习率和训练方式。 3. 确保您使用的embedding模型与训练该embedding时的模型保持一致。根据引用\[3\]中的说明,使用embedding生成新的图片时,最好和训练这个embedding时的模型保持一致,以确保生成效果良好。 如果您仍然遇到报错问题,建议您查看相关的文档、教程或寻求更专业的技术支持来解决该问题。 #### 引用[.reference_title] - *1* *2* *3* [Stable Diffusion攻略集(Stable Diffusion官方文档、kaggle notebook、webui资源帖)](https://blog.csdn.net/qq_56591814/article/details/128385416)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值