baseline代码
对于siRNA数据集进行深度学习的模型训练框架代码学习
本段baseline代码分为十个部分
1.依赖库的导入
导入代码实现需要的库,用于文件操作、深度学习、数据处理、模型评估
并建立了自定义函数确保随机数生成的可复现性与可控性
2.创建基因组分词器
将基因组序列大写化,并分割成片段化的n-gram
其中尾端多余的不符合设定片段长的碱基序列会被舍弃
返回结果为处理后的分词列表
3.创建基因组词汇表
使用类初始化__init__方法创建词汇到索引,索引到词汇的映射
使用create类方法,统计每个对象出现的频率并进行排序,并设定进入词汇表的对象的最大容量和最低频率,进行部分保留进入词汇表
此步骤的原因有其分子遗传学解释,在长段发挥作用的RNA中,出现频率较高的片段通常具有较高的性状表现,很大可能是RNA作用效果片段
4.siRNA数据集转换
首先使用__init__方法初始化数据集
__len__方法返回数据集长度
getitem方法获取指定索引的样本,并将文本数据转换为张量数据
tokenize_and_encode进行分词编码,返回张量格式的填充序列
5.siRNA Model
基于GRU神经网络,是处理siRNA的主体模型
__init__方法初始化模型各层
前向传播方法forward,输入序列进行嵌入并进行Dropout处理,并输出拼接
在逻辑上,输出信息取决于过去信息、当前信息和门
6.评估指标计算函数
计算模型的评估指标
通过平均绝对误差实现对于精确度、召回率和F1得分的计算,并最终返回综合评分
7.模型评估函数
使用测试集评估模型的性能,遍历测试数据,计算评估指标,并返回得分
8.模型训练函数
用于训练模型,不断进行训练-评估的epoch,并在每次epoch之后通过评估,保留最佳模型
一次epoch中,将进行
前向传播:输入数据,得到预测值
计算损失:使用损失函数计算预测值和真实值之间的差异
反向传播:使用损失函数的梯度反向传播到模型的每一层,计算参数梯度
参数更新:使用优化器根据梯度更新参数
9.模型训练的主程序
优化模型的各个参数,使得模型可以根据数据的特征准确预测或者分类
10.进行测试的程序
负责进行预测并保存预测结果
宏观来看,整个代码功能完备明晰
包含数据准备、模型定义、参数初始化、损失函数与优化器、训练过程模型评估机制
题目背景解读
使用siRNA靶向结合病原体mRNA,抑制或阻断该mRNA的表达,是siRNA基因疗法的基本原理,关键在于相似互补序列的碱基互补配对,以及siRNA的特征序列。由于该疗法是涉及微观层面的分子基因反应,临床实验中药物作用疗效鉴定困难且成本巨大,因此该类疗法通常研发周期长,药物疗效不明确,发展方向难以辨别。
使用人工智能模型进行数据学习,迭代优化后可以对于siRNA残留量进行较为准确的预测,实现了低成本短周期的药物疗效判别,将极大促进基因疗法药物的开发与临床实验性应用。
代码来源:Datawhale AI夏令营 模型训练的代码脚本文件