basenji框架揭秘

basenji_train.py:

代码运行流程:根据params_small.json文件获取模型参数与训练参数,然后使用seqnn.SeqNN类构建模型,然后使用trainer.Trainer类构建seqnn_trainer,以对模型进行训练,然后通过seqnn_trainer调用compile与fit函数执行训练。

读入json文件

使用以下代码读取params_small.json文件,将模型和训练的参数传给basenji_train.py的params变量

 with open(params_file) as params_open:
   params = json.load(params_open)
 params_model = params['model'] # model参数也是一个字典
 params_train = params['train']# train参数也是一个字典

   
   
   
   

    params_model内容如下:
    (其中trunk字典是模型每一层结构的参数,后面将被设置为seqnn_model实例的属性)

    "model": {
        "seq_length": 131072,
        "target_length": 1024,
    
    "augment_rc": true,
    "augment_shift": 3,
    
    "activation": "gelu",
    "batch_norm": true,
    "bn_momentum": 0.9,
    
    "trunk": [
        {
            "name": "conv_block",
            "filters": 64,
            "kernel_size": 15,
            "pool_size": 8
        },
        {
            "name": "conv_tower",
            "filters_init": 64,
            "filters_mult": 1.125,
            "kernel_size": 5,
            "pool_size": 4,
            "repeat": 2
        },
        {
            "name": "dilated_residual",
            "filters": 32,
            "rate_mult": 2,
            "repeat": 6,
            "dropout": 0.25
        },
        {
            "name": "conv_block",
            "filters": 64,
            "dropout": 0.05
        }
    ],
    "head": {
        "name": "dense",
        "units": 3,
        "activation": "softplus"
    }
    

    }

      构建模型seqnn_model

      接下来将params_model传递给seqnn.SeqNN来构建模型,所用命令如下:

      seqnn_model = seqnn.SeqNN(params_model) # line:104
      # seqnn_model的属性包含params_model中的参数,
      # 此外可以用seqnn_model.model.summary()查看模型的信息。
      
       
       
       
       

        该类存放于seqnn.py文件中,其方法有:

        1. init
          def __init__(self, params):
            self.set_defaults()
            for key, value in params.items():
              self.__setattr__(key, value) # 将params里的属性设置为该类的实例
            self.build_model()
            self.ensemble = None
            self.embed = None
        
         
         
         
         

          params参数即params_model字典,该构造函数将params_model中的键值对设置成seqnn_model实例的属性。

          1. build_blocks
            def build_block(self, current, block_params)
            参数1:current,即输入,由tf.keras.Input生成,并不带有实际的数据,
            参数2:block_params,字典形式,即params_model[‘trunk’]字典,这里面存放的是对模型的每一层的参数定义
            功能:使用blocks.py中所定义的block(即卷积,全连接等操作)来对输入进行操作,并返回一个current
          2. build_model
            def build_model(self, save_reprs=False)
            该函数依次读取self.trunk中的参数,然后调用build_block函数构建对应的网络结构。

          构建seqnn_trainer

          seqnn_trainer = trainer.Trainer(params_train, train_data, eval_data, options.out_dir)
          
           
           
           
           

            编译并训练模型

            代码如下所示:

            seqnn_trainer.compile(seqnn_model) 
            seqnn_trainer.fit(seqnn_model)
            
             
             
             
             
              评论
              添加红包

              请填写红包祝福语或标题

              红包个数最小为10个

              红包金额最低5元

              当前余额3.43前往充值 >
              需支付:10.00
              成就一亿技术人!
              领取后你会自动成为博主和红包主的粉丝 规则
              hope_wisdom
              发出的红包
              实付
              使用余额支付
              点击重新获取
              扫码支付
              钱包余额 0

              抵扣说明:

              1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
              2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

              余额充值