Briefings in bioinformatics2022 | FP-GNN+:用于分子性质预测的versatile DL架构 【train.py文件代码逐行解释】

这段代码实现了一个训练函数 training,用于在指定数据集上训练 FP-GNN 模型。下面逐行解释代码:

from argparse import Namespace
from logging import Logger
import numpy as np
import os
from fpgnn.train import fold_train
from fpgnn.tool import set_log, set_train_argument, get_task_name, mkdir

这段代码导入了所需的 Python 模块,包括 argparseloggingnumpyos,以及自定义的 fold_trainset_logset_train_argumentget_task_name 和 mkdir 等函数。

def training(args,log):
    info = log.info
    
    seed_first = args.seed
    data_path = args.data_path
    save_path = args.save_path
    
    score = []

这段代码定义了一个名为 training 的函数,该函数接受两个参数:args 和 logargs 是一个 Namespace 类型的对象,其中包含了训练 FP-GNN 模型所需的所有参数;log 是一个 Logger 类型的对象,用于记录训练过程中的日志信息。info = log.info 表示将 log 对象中的 info 方法赋值给 info 变量,以便在后续代码中使用。

    for num_fold in range(args.num_folds):
        info(f'Seed {args.seed}')
        args.seed = seed_first + num_fold
        args.save_path = os.path.join(save_path, f'Seed_{args.seed}')
        mkdir(args.save_path)
        
        fold_score = fold_train(args,log)
        
        score.append(fold_score)
    score = np.array(score)

这段代码使用 for 循环对数据集进行交叉验证,即将数据集分成若干个子集,每次使用其中一个子集作为测试集,其余子集作为训练集。range(args.num_folds) 表示循环次数等于交叉验证的折数。info(f'Seed {args.seed}') 表示记录当前交叉验证的随机种子。args.seed = seed_first + num_fold 表示将随机种子设为当前交叉验证的编号加上初始随机种子,以保证每次交叉验证使用的随机数序列不同。args.save_path = os.path.join(save_path, f'Seed_{args.seed}') 表示设置模型保存路径,该路径包含了随机种子信息,以便于区分不同的模型。mkdir(args.save_path) 表示创建模型保存路径。

接下来,调用 fold_train(args,log) 函数进行一次交叉验证的训练,并将训练得到的分数保存到 fold_score 变量中。最后,将 fold_score 添加到 score 列表中,并将 score 转换为 numpy 数组。

    info(f'Running {args.num_folds} folds in total.')
    if args.num_folds > 1:
        for num_fold, fold_score in enumerate(score):
            info(f'Seed {seed_first + num_fold} : test {args.metric} = {np.nanmean(fold_score):.6f}')
            if args.task_num > 1:
                for one_name,one_score in zip(args.task_names,fold_score):
                    info(f'    Task {one_name} {args.metric} = {one_score:.6f}')
    ave_task_score = np.nanmean(score, axis=1)
    score_ave = np.nanmean(ave_task_score)
    score_std = np.nanstd(ave_task_score)
    info(f'Average test {args.metric} = {score_ave:.6f} +/- {score_std:.6f}')
    
    if args.task_num > 1:
        for i,one_name in enumerate(args.task_names):
            info(f'    average all-fold {one_name} {args.metric} = {np.nanmean(score[:, i]):.6f} +/- {np.nanstd(score[:, i]):.6f}')
    
    return score_ave,score_std

这段代码计算交叉验证的平均分数,并将平均分数和标准差记录到日志中。info(f'Running {args.num_folds} folds in total.') 表示记录总共进行了多少次交叉验证。如果折数大于 1,就使用 for 循环遍历每一次交叉验证的结果,并记录测试分数。如果任务数大于 1,就遍历每一个任务,并记录平均分数和标准差。

最后,计算所有交叉验证的平均分数和标准差,并将结果记录到日志中。如果任务数大于 1,就遍历每一个任务,并记录所有交叉验证的平均分数和标准差。

if __name__ == '__main__':
    args = set_train_argument()
    log = set_log('train',args.log_path)
    training(args,log)

这段代码表示当该脚本作为主程序运行时,先调用 set_train_argument() 函数获取训练参数,再调用 set_log() 函数创建日志对象,最后调用 training() 函数进行模型训练。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值