训练部分代码
train部分代码完成
train代码
这部分是服务器的训练的主代码,内容比较多,而且变量的设计复杂,需要仔细的剥开讲讲。
def SL_train(config, modelsORpara):
torch.cuda.set_device(config.device)
roc_datas, prc_datas = [], []
repres_list, label_list = [], []
train_seq, test_seq = [], []
train_label, test_label = [], []
pos_list = []
neg_list = []
best_performance = []
data_statistic = [] # train pos, train neg, test pos, test neg
time_use = []
step_log_interval = []
train_metric_record = []
train_loss_record = []
step_test_interval = []
test_metric_record = []
test_loss_record = []
首先是定义了一大波参数,我们开始顺次标注一下
roc_datas
,prc_datas
:用来存放画图的roc和pr数据repres_list
,label_list
:用来存放模型得倒的特征提取和标签列表train_seq
,test_seq
:测试的sequence和测试的sequencetrain_label
,test_label
:训练集的label和测试集的labelpos_list
,neg_list
:存放模型预测的正负样本的置信度best_performance
:存放模型最好的表现的数组data_statistic
:数据统计的函数统计这几个train pos, train neg, test pos, test neg的数量time_use
:每个模型训练的耗时记录step_log_interval
:用来画epochlog的step的横坐标train_metric_record
:用来画epochlog的其一纵坐标准确度train_loss_record
:用来画epochlog的其一纵坐标损失值(loss)
if_same = config.if_same
if_same = True
savepath = '/data/result/' + config.learn_name
# if not os.path.exists('../data/result/'):
# os.mkdir('../data/result/')
if not os.path.exists(savepath + '/plot'):
os.mkdir(savepath + '/plot')
util_file.filiter_fasta(config.path_data, savepath, skip_first=False)
用来保存数据记录的文件夹的创建,以及切分相应的数据集作为对应的四个小数据集,方便后面的数据统计操作。
names