Datawhale AI夏令营学习笔记:官方发布Baseline代码学习

依赖安装

官方Baseline代码依赖以下环境:

  • Python3
  • pandas
  • scikit-learn
  • rdkit

安装这些依赖的命令如下:

!pip install pandas
!pip install -U scikit-learn
!pip install rdkit

安装成功的提示信息如下:

Looking in indexes: https://mirrors.cloud.aliyuncs.com/pypi/simple
Requirement already satisfied: pandas in /usr/local/lib/python3.10/site-packages (2.2.2)
Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/site-packages (from pandas) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/site-packages (from pandas) (2024.1)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/site-packages (from pandas) (2024.1)
Requirement already satisfied: numpy>=1.22.4 in /usr/local/lib/python3.10/site-packages (from pandas) (1.26.4)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas) (1.16.0)

Looking in indexes: https://mirrors.cloud.aliyuncs.com/pypi/simple
Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/site-packages (1.5.0)
Collecting scikit-learn
  Downloading https://mirrors.cloud.aliyuncs.com/pypi/packages/f2/60/6c589c91e474721efdcec82ea9cc5c743359e52637e46c364ee5236666ef/scikit_learn-1.5.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.4/13.4 MB 79.6 MB/s eta 0:00:0000:010:01
Requirement already satisfied: scipy>=1.6.0 in /usr/local/lib/python3.10/site-packages (from scikit-learn) (1.12.0)
Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.10/site-packages (from scikit-learn) (1.4.2)
Requirement already satisfied: numpy>=1.19.5 in /usr/local/lib/python3.10/site-packages (from scikit-learn) (1.26.4)
Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.10/site-packages (from scikit-learn) (3.5.0)
Installing collected packages: scikit-learn
  Attempting uninstall: scikit-learn
    Found existing installation: scikit-learn 1.5.0
    Uninstalling scikit-learn-1.5.0:
      Successfully uninstalled scikit-learn-1.5.0

导入库

导入所需库:

import pickle
import pandas as pd
from tqdm import tqdm
from sklearn.ensemble import RandomForestRegressor
from rdkit.Chem import rdMolDescriptors
from rdkit import RDLogger, Chem
import numpy as np
RDLogger.DisableLog('rdApp.*')

特征提取

数据格式:

  • rxnid: 数据的ID标识,无实际意义
  • Reactant1: 反应物1
  • Reactant2: 反应物2
  • Product: 产物
  • Additive: 添加剂
  • Solvent: 溶剂
  • Yield: 产率

化学分子的SMILES表达式用于提取分子指纹(向量)。

SMILES (Simplified Molecular Input Line Entry System) 是一种将化学分子用ASCII字符表示的方法。

Morgan指纹 是位向量(bit vector)形式的特征,由0和1组成的向量。

RDKit 是化学信息学的主要工具,开源并支持多平台。

特征提取函数:

def mfgen(mol, nBits=2048, radius=2):
    fp = rdMolDescriptors.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nBits)
    return np.array(list(map(eval, list(fp.ToBitString()))))

def vec_cpd_lst(smi_lst):
    smi_set = list(set(smi_lst))
    smi_vec_map = {}
    for smi in tqdm(smi_set):
        mol = Chem.MolFromSmiles(smi)
        smi_vec_map[smi] = mfgen(mol)
    smi_vec_map[''] = np.zeros(2048)
    
    vec_lst = [smi_vec_map[smi] for smi in smi_lst]
    return np.array(vec_lst)

数据集加载和向量化

加载训练集和测试集:

dataset_dir = '../dataset'   # 若在AI Studio上,将此路径改为'dataset'

train_df = pd.read_csv(f'{dataset_dir}/round1_train_data.csv')
test_df = pd.read_csv(f'{dataset_dir}/round1_test_data.csv')

print(f'Training set size: {len(train_df)}, test set size: {len(test_df)}')

将SMILES转化为分子指纹:

train_rct1_smi = train_df['Reactant1'].to_list()
train_rct2_smi = train_df['Reactant2'].to_list()
train_add_smi = train_df['Additive'].to_list()
train_sol_smi = train_df['Solvent'].to_list()

train_rct1_fp = vec_cpd_lst(train_rct1_smi)
train_rct2_fp = vec_cpd_lst(train_rct2_smi)
train_add_fp = vec_cpd_lst(train_add_smi)
train_sol_fp = vec_cpd_lst(train_sol_smi)

train_x = np.concatenate([train_rct1_fp, train_rct2_fp, train_add_fp, train_sol_fp], axis=1)
train_y = train_df['Yield'].to_numpy()

test_rct1_smi = test_df['Reactant1'].to_list()
test_rct2_smi = test_df['Reactant2'].to_list()
test_add_smi = test_df['Additive'].to_list()
test_sol_smi = test_df['Solvent'].to_list()

test_rct1_fp = vec_cpd_lst(test_rct1_smi)
test_rct2_fp = vec_cpd_lst(test_rct2_smi)
test_add_fp = vec_cpd_lst(test_add_smi)
test_sol_fp = vec_cpd_lst(test_sol_smi)
test_x = np.concatenate([test_rct1_fp, test_rct2_fp, test_add_fp, test_sol_fp], axis=1)

模型训练和保存

使用随机森林进行建模:

model = RandomForestRegressor(n_estimators=10, max_depth=10, min_samples_split=2, min_samples_leaf=1, n_jobs=-1)
model.fit(train_x, train_y)

保存模型:

with open('./random_forest_model.pkl', 'wb') as file:
    pickle.dump(model, file)

模型加载和预测

加载模型:

with open('random_forest_model.pkl', 'rb') as file:
    loaded_model = pickle.load(file)

预测结果:

test_pred = loaded_model.predict(test_x)

生成提交文件

生成赛题要求的提交文件:

ans_str_lst = ['rxnid,Yield']
for idx, y in enumerate(test_pred):
    ans_str_lst.append(f'test{idx+1},{y:.4f}')
with open('./submit.txt', 'w') as fw:
    fw.writelines('\n'.join(ans_str_lst))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值