学习平台
第二届世界科学智能大赛 物质科学赛道:催化反应产率预测
http://competition.sais.com.cn/competitionDetail/532233/format
比赛内容是通过利用历史催化反应数据,并结合AI技术,可以预测新催化反应的产率,从而有效地帮助科研人员和产业界加快高活性反应条件的筛选速度,减少资源与人力的消耗,促进新物质的创造与合成。
赛题背景
比赛提供在药物合成中常见的多种催化反应实验数据,其中包括反应的底物、包括催化剂在内的反应添加剂、反应溶剂以及反应产物,期待选手通过分析反应数据,利用机器学习、深度学习算法或者大语言模型,建立产率预测模型,从而辅助未知新反应的反应条件筛选。
task1:构建一个能够准确预测碳氮成键反应产率的预测模型
一站式baseline
魔塔notebook,PAI-DSW是为算法开发者量身打造的云端深度学习开发环境,内置JupyterLab、WebIDE及Terminal,无需任何运维配置即可编写。
在cpu环境运行平台中,通过命令行完成数据文件使用
!pip install pandas
!pip install -U scikit-learn
!pip install rdkit
模型运行所需的环境依赖包括
- Python3
- pandas 强大的分析结构化数据的工具集;它的使用基础是Numpy(提供高性能的矩阵运算);用于数据挖掘和数据分析,同时也提供数据清洗功能。
- scikit-learn 基于Numpy, Scipy和matplotlib,包含了大量的机器学习算法实现,包括分类、回归、聚类和降维等,还包含了诸多模型评估及选择的方法。
- rdkit 常用的生物化学信息python工具包。它提供了大量对化学分子2D或3D的计算操作,可生成用于机器学习的分子描述符。
然后进行库的导入
# 首先,导入库
import pickle
import pandas as pd
from tqdm import tqdm
from sklearn.ensemble import RandomForestRegressor
from rdkit.Chem import rdMolDescriptors
from rdkit import RDLogger,Chem
import numpy as np
RDLogger.DisableLog('rdApp.*')
特征提取
这是机器学习模型运行最重要的步骤,机器学习作为较为简单的模型需要对输入数据进行处理。
官方发布数据的相关字段如下rxnid,Reactant1,Reactant2,Product,Additive,Solvent,Yield。其中:
- rxnid 对数据的id标识,无实际意义
- Reactant1 反应物1
- Reactant2 反应物2
- Product 产物
- Additive 添加剂(包括催化剂catalyst等辅助反应物合成但是不对产物贡献原子的部分)
- Solvent 溶剂
- Yield 产率 其中Reactant1,Reactant2,Product,Additive,Solvent都是由SMILES表示。
生成分子指纹(Morgan Fingerprint)描述符,并将其转换为位向量(bit vector)形式。
def mfgen(mol,nBits=2048, radius=2):
'''
Parameters
----------
mol : mol
RDKit mol object.
nBits : int
Number of bits for the fingerprint.
radius : int
Radius of the Morgan fingerprint.
Returns
-------
mf_desc_map : ndarray
ndarray of molecular fingerprint descriptors.
'''
fp = rdMolDescriptors.GetMorganFingerprintAsBitVect(mol,radius=radius,nBits=nBits)
return np.array(list(map(eval,list(fp.ToBitString()))))
def vec_cpd_lst(smi_lst):
smi_set = list(set(smi_lst))
smi_vec_map = {}
for smi in tqdm(smi_set): # tqdm:
mol = Chem.MolFromSmiles(smi)
smi_vec_map[smi] = mfgen(mol)
smi_vec_map[''] = np.zeros(2048)
vec_lst = [smi_vec_map[smi] for smi in smi_lst]
return np.array(vec_lst)
加载训练和测试数据。
dataset_dir = '../dataset' #文件路径
train_df = pd.read_csv(f'{dataset_dir}/round1_train_data.csv')
test_df = pd.read_csv(f'{dataset_dir}/round1_test_data.csv')
print(f'Training set size: {len(train_df)}, test set size: {len(test_df)}')
提取化学反应物的SMILES字符串,转换为分子指纹,然后将这些指纹拼接为一个大的特征向量,用于机器学习模型的训练和测试。
# 数据读取
train_rct1_smi = train_df['Reactant1'].to_list()
train_rct2_smi = train_df['Reactant2'].to_list()
train_add_smi = train_df['Additive'].to_list()
train_sol_smi = train_df['Solvent'].to_list()
# 转化分子指纹
train_rct1_fp = vec_cpd_lst(train_rct1_smi)
train_rct2_fp = vec_cpd_lst(train_rct2_smi)
train_add_fp = vec_cpd_lst(train_add_smi)
train_sol_fp = vec_cpd_lst(train_sol_smi)
# 特征向量拼接
train_x = np.concatenate([train_rct1_fp,train_rct2_fp,train_add_fp,train_sol_fp],axis=1)
train_y = train_df['Yield'].to_numpy()
#
test_rct1_smi = test_df['Reactant1'].to_list()
test_rct2_smi = test_df['Reactant2'].to_list()
test_add_smi = test_df['Additive'].to_list()
test_sol_smi = test_df['Solvent'].to_list()
test_rct1_fp = vec_cpd_lst(test_rct1_smi)
test_rct2_fp = vec_cpd_lst(test_rct2_smi)
test_add_fp = vec_cpd_lst(test_add_smi)
test_sol_fp = vec_cpd_lst(test_sol_smi)
test_x = np.concatenate([test_rct1_fp,test_rct2_fp,test_add_fp,test_sol_fp],axis=1)
随机森林建模
# Model fitting
model = RandomForestRegressor(n_estimators=15,max_depth=10,min_samples_split=2,min_samples_leaf=1,n_jobs=-1)
model.fit(train_x,train_y) # 数据集训练模型
# 保存
with open('./random_forest_model.pkl', 'wb') as file:
pickle.dump(model, file)
# 加载
with open('random_forest_model.pkl', 'rb') as file:
loaded_model = pickle.load(file)
# 预测\推理
test_pred = loaded_model.predict(test_x)
参数解释:
- n_estimators=10: 决策树的个数,越多越好;但是越多意味着计算开销越大;
- max_depth: (default=None)设置树的最大深度,默认为None;
- min_samples_split: 根据属性划分节点时,最少的样本数;
- min_samples_leaf: 叶子节点最少的样本数;
- n_jobs=1: 并行job个数,-1表示使用所有cpu进行并行计算。
最后生成要求的submit
ans_str_lst = ['rxnid,Yield']
for idx,y in enumerate(test_pred):
ans_str_lst.append(f'test{idx+1},{y:.4f}')
with open('./submit.txt','w') as fw:
fw.writelines('\n'.join(ans_str_lst))