basenji框架代码解密

basenji_train.py:

代码运行流程:根据params_small.json文件获取模型参数与训练参数,然后使用seqnn.SeqNN类构建模型,然后使用trainer.Trainer类构建seqnn_trainer,以对模型进行训练,然后通过seqnn_trainer调用compile与fit函数执行训练。

读入json文件

使用以下代码读取params_small.json文件,将模型和训练的参数传给basenji_train.py的params变量

 with open(params_file) as params_open:
   params = json.load(params_open)
 params_model = params['model'] # model参数也是一个字典
 params_train = params['train']# train参数也是一个字典

params_model内容如下:
(其中trunk字典是模型每一层结构的参数,后面将被设置为seqnn_model实例的属性)

"model": {
    "seq_length": 131072,
    "target_length": 1024,

    "augment_rc": true,
    "augment_shift": 3,

    "activation": "gelu",
    "batch_norm": true,
    "bn_momentum": 0.9,

    "trunk": [
        {
            "name": "conv_block",
            "filters": 64,
            "kernel_size": 15,
            "pool_size": 8
        },
        {
            "name": "conv_tower",
            "filters_init": 64,
            "filters_mult": 1.125,
            "kernel_size": 5,
            "pool_size": 4,
            "repeat": 2
        },
        {
            "name": "dilated_residual",
            "filters": 32,
            "rate_mult": 2,
            "repeat": 6,
            "dropout": 0.25
        },
        {
            "name": "conv_block",
            "filters": 64,
            "dropout": 0.05
        }
    ],
    "head": {
        "name": "dense",
        "units": 3,
        "activation": "softplus"
    }
}

构建模型seqnn_model

接下来将params_model传递给seqnn.SeqNN来构建模型,所用命令如下:

seqnn_model = seqnn.SeqNN(params_model) # line:104
# seqnn_model的属性包含params_model中的参数,
# 此外可以用seqnn_model.model.summary()查看模型的信息。

该类存放于seqnn.py文件中,其方法有:

  1. init
  def __init__(self, params):
    self.set_defaults()
    for key, value in params.items():
      self.__setattr__(key, value) # 将params里的属性设置为该类的实例
    self.build_model()
    self.ensemble = None
    self.embed = None

params参数即params_model字典,该构造函数将params_model中的键值对设置成seqnn_model实例的属性。

  1. build_blocks
    def build_block(self, current, block_params)
    参数1:current,即输入,由tf.keras.Input生成,并不带有实际的数据,
    参数2:block_params,字典形式,即params_model[‘trunk’]字典,这里面存放的是对模型的每一层的参数定义
    功能:使用blocks.py中所定义的block(即卷积,全连接等操作)来对输入进行操作,并返回一个current
  2. build_model
    def build_model(self, save_reprs=False)
    该函数依次读取self.trunk中的参数,然后调用build_block函数构建对应的网络结构。

构建seqnn_trainer

seqnn_trainer = trainer.Trainer(params_train, train_data, eval_data, options.out_dir)

编译并训练模型

代码如下所示:

seqnn_trainer.compile(seqnn_model) 
seqnn_trainer.fit(seqnn_model)
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值