DataWhale图网络学习(七)超大规模数据集类的创建
1 构建数据集
这里我们使用PCQM4M-LSC数据集,它是一个分子图的量子特性回归数据集,包含了3,803,453个图。我们通过自定义数据集类来导入数据集:
import os
import os.path as osp
import pandas as pd
import torch
from ogb.utils import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import download_url, extract_zip
from rdkit import RDLogger
from torch_geometric.data import Data, Dataset
import shutil
# Turn off RDKit log output for every channel matching 'rdApp.*'.
RDLogger.DisableLog('rdApp.*')
class MyPCQM4MDataset(Dataset):
    """PCQM4M molecule dataset that turns SMILES strings into graphs on access.

    The raw CSV is read once in the constructor and only its ``smiles`` and
    ``homolumogap`` columns are kept. Graph objects are built lazily in
    :meth:`get`, one molecule at a time.
    """

    def __init__(self, root):
        # Set before the parent constructor runs because download() reads
        # self.url (presumably the parent triggers download() when the raw
        # file is missing -- confirm against the torch_geometric docs).
        self.url = 'https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/pcqm4m_kddcup2021.zip'
        super().__init__(root)

        table = pd.read_csv(osp.join(root, 'raw/data.csv.gz'))
        self.smiles_list = table['smiles']
        self.homolumogap_list = table['homolumogap']

    @property
    def raw_file_names(self):
        """Name of the raw file expected under the ``raw`` directory."""
        return 'data.csv.gz'

    def download(self):
        """Fetch the archive, unpack it under ``root`` and move the CSV into ``raw/``."""
        archive = download_url(self.url, self.root)
        extract_zip(archive, self.root)
        os.unlink(archive)  # the zip is no longer needed once extracted
        extracted_csv = osp.join(self.root, 'pcqm4m_kddcup2021/raw/data.csv.gz')
        shutil.move(extracted_csv, osp.join(self.root, 'raw/data.csv.gz'))

    def len(self):
        """Number of molecules, i.e. number of rows in the CSV."""
        return len(self.smiles_list)

    def get(self, idx):
        """Build the ``Data`` object for molecule ``idx`` from its SMILES string."""
        smiles = self.smiles_list[idx]
        gap = self.homolumogap_list[idx]
        mol_graph = smiles2graph(smiles)

        # Sanity checks: one feature row per edge and per node.
        assert len(mol_graph['edge_feat']) == mol_graph['edge_index'].shape[1]
        assert len(mol_graph['node_feat']) == mol_graph['num_nodes']

        node_feat = torch.from_numpy(mol_graph['node_feat']).to(torch.int64)
        connectivity = torch.from_numpy(mol_graph['edge_index']).to(torch.int64)
        bond_feat = torch.from_numpy(mol_graph['edge_feat']).to(torch.int64)
        target = torch.Tensor([gap])

        return Data(
            x=node_feat,
            edge_index=connectivity,
            edge_attr=bond_feat,
            y=target,
            num_nodes=int(mol_graph['num_nodes']),
        )

    def get_idx_split(self):
        """Load the split dictionary shipped with the archive, converted to torch tensors."""
        split_path = osp.join(self.root, 'pcqm4m_kddcup2021/split_dict.pt')
        return replace_numpy_with_torchtensor(torch.load(split_path))
之后我们可以通过load该数据集分批读入数据进行训练和预测。
2 构建模型
我们要对图进行表征学习:首先对节点进行表征,然后通过图池化得到图的表征。这部分内容可参考图的表征学习。
2.1 图表征学习模块
首先看图表征学习模块(这里暂不展开节点嵌入的具体过程):
import torch
from torch import nn
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
from gin_node import GINNodeEmbedding
class GINGraphPooling(nn.Module):
    """Graph-level predictor: GIN node embedding, graph pooling, linear head."""

    def __init__(self, num_tasks=1, num_layers=5, emb_dim=300, residual=False, drop_ratio=0, JK="last", graph_pooling="sum"):
        """GIN Graph Pooling Module

        The module first embeds every node of the graph with GINNodeEmbedding,
        then pools the node embeddings into a graph embedding, and finally
        applies one linear layer to obtain the graph representation.

        Args:
            num_tasks (int, optional): number of labels to be predicted; this is
                the dimension of the graph representation. Defaults to 1.
            num_layers (int, optional): number of GINConv layers. Defaults to 5.
            emb_dim (int, optional): dimension of node embedding. Defaults to 300.
            residual (bool, optional): adding residual connection or not. Defaults to False.
            drop_ratio (float, optional): dropout rate. Defaults to 0.
            JK (str, optional): "last" keeps only the node embeddings of the last
                layer, "sum" adds up the node embeddings of all layers.
                Defaults to "last".
            graph_pooling (str, optional): pooling method of node embedding, one of
                "sum", "mean", "max", "attention" and "set2set". Defaults to "sum".

        Raises:
            ValueError: if ``num_layers`` is below 2 or ``graph_pooling`` is unknown.

        Out:
            graph representation
        """
        super().__init__()

        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.gnn_node = GINNodeEmbedding(num_layers, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual)

        # Pick the pooling operator and, alongside it, the width of the pooled
        # vector that the prediction head receives.
        head_in_dim = self.emb_dim
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            gate = nn.Sequential(
                nn.Linear(emb_dim, emb_dim),
                nn.BatchNorm1d(emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, 1),
            )
            self.pool = GlobalAttention(gate_nn=gate)
        elif graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
            # The head for set2set pooling takes a 2 * emb_dim wide input.
            head_in_dim = 2 * self.emb_dim
        else:
            raise ValueError("Invalid graph pooling type.")

        self.graph_pred_linear = nn.Linear(head_in_dim, self.num_tasks)

    def forward(self, batched_data):
        """Return one prediction row per graph in ``batched_data``."""
        node_emb = self.gnn_node(batched_data)
        graph_emb = self.pool(node_emb, batched_data.batch)
        prediction = self.graph_pred_linear(graph_emb)

        if not self.training:
            # At inference time the output is clamped to [0, 50]; the lower
            # bound keeps predictions non-negative.
            prediction = torch.clamp(prediction, min=0, max=50)
        return prediction
2.2 节点嵌入模块
然后我们需要节点嵌入模块:
import torch
from mol_encoder import AtomEncoder
from gin_conv import GINConv
import torch.nn.functional as F
# GNN to generate node embedding
class GINNodeEmbedding(torch.nn.Module):
    """Stack of GINConv layers that produces node embeddings.

    Output:
        node representations
    """

    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK="last", residual=False):
        """GIN Node Embedding Module

        Embeds the nodes of a graph with several GINConv layers.

        Args:
            num_layers (int): number of GINConv layers; must be at least 2.
            emb_dim (int): dimension of the node embeddings.
            drop_ratio (float, optional): dropout rate. Defaults to 0.5.
            JK (str, optional): "last" returns the last layer's embeddings,
                "sum" returns the sum over the input embedding and all layers.
                Defaults to "last".
            residual (bool, optional): add a residual connection around each
                layer. Defaults to False.

        Raises:
            ValueError: if ``num_layers`` is below 2.
        """
        super(GINNodeEmbedding, self).__init__()
        self.num_layers = num_layers
        self.drop_ratio = drop_ratio
        self.JK = JK
        # add residual connection or not
        self.residual = residual

        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        # One convolution and one batch norm per layer.
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, batched_data):
        """Compute node representations for ``batched_data``.

        Raises:
            ValueError: if ``self.JK`` is neither "last" nor "sum".
        """
        x, edge_index, edge_attr = batched_data.x, batched_data.edge_index, batched_data.edge_attr

        # Turn the categorical atom attributes into the input node embedding.
        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layers):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                # Out-of-place add on purpose: when dropout is a no-op it can
                # hand back the ReLU output itself, and modifying that tensor
                # in place breaks the backward pass.
                h = h + h_list[layer]

            h_list.append(h)

        # Combine the per-layer representations.
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            # h_list holds the input embedding plus one entry per layer.
            node_representation = sum(h_list)
        else:
            raise ValueError(f"Invalid JK type: {self.JK!r}. Expected 'last' or 'sum'.")

        return node_representation
2.3 GINConv模块
在做节点嵌入时,我们是通过GINConv模块进行节点表征学习的:
import torch
from torch import nn
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import BondEncoder
### GIN convolution along the graph structure
class GINConv(MessagePassing):
    """GIN convolution that adds bond embeddings to the neighbour messages."""

    def __init__(self, emb_dim):
        '''
        emb_dim (int): node embedding dimensionality
        '''
        super().__init__(aggr="add")

        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.BatchNorm1d(emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )
        # Learnable weight of the node's own features, initialised to 0.
        self.eps = nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        """Return the updated node features."""
        # Turn the categorical edge attributes into edge embeddings first.
        bond_emb = self.bond_encoder(edge_attr)
        self_term = (1 + self.eps) * x
        neighbour_term = self.propagate(edge_index, x=x, edge_attr=bond_emb)
        return self.mlp(self_term + neighbour_term)

    def message(self, x_j, edge_attr):
        """Message for one edge: ReLU of source-node features plus edge embedding."""
        combined = x_j + edge_attr
        return F.relu(combined)

    def update(self, aggr_out):
        """Pass the aggregated messages through unchanged."""
        return aggr_out
2.4 节点嵌入维度模块
在进行节点表征学习时,我们需要先将类别型的原子属性和化学键属性嵌入为emb_dim维的向量:
import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
# Per-attribute category counts, fetched once at import time. Entry i is used
# below as the vocabulary size of the embedding table for attribute i.
full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()
class AtomEncoder(torch.nn.Module):
    """Embed categorical atom attributes.

    With `N` attributes per atom, an atom is described by `[x1, x2, ..., xN]`
    where every `xi` is categorical; `full_atom_feature_dims[i]` holds the
    number of categories of `xi`. The encoder maps such an attribute vector to
    a single atom embedding of dimension `emb_dim`.
    """

    def __init__(self, emb_dim):
        super().__init__()

        # One embedding table per attribute, each with its own vocabulary size.
        self.atom_embedding_list = torch.nn.ModuleList()
        for num_categories in full_atom_feature_dims:
            table = torch.nn.Embedding(num_categories, emb_dim)
            torch.nn.init.xavier_uniform_(table.weight.data)
            self.atom_embedding_list.append(table)

    def forward(self, x):
        """Sum the embeddings looked up for each attribute column of ``x``."""
        num_columns = x.shape[1]
        return sum(
            self.atom_embedding_list[col](x[:, col])
            for col in range(num_columns)
        )
class BondEncoder(torch.nn.Module):
    """Embed categorical bond attributes into one vector of size ``emb_dim``.

    One embedding table is created per entry of ``full_bond_feature_dims``;
    the per-attribute embeddings are added together.
    """

    def __init__(self, emb_dim):
        super().__init__()

        self.bond_embedding_list = torch.nn.ModuleList()
        for num_categories in full_bond_feature_dims:
            table = torch.nn.Embedding(num_categories, emb_dim)
            torch.nn.init.xavier_uniform_(table.weight.data)
            self.bond_embedding_list.append(table)

    def forward(self, edge_attr):
        """Sum the embeddings looked up for each attribute column of ``edge_attr``."""
        num_columns = edge_attr.shape[1]
        return sum(
            self.bond_embedding_list[col](edge_attr[:, col])
            for col in range(num_columns)
        )
3 模型训练与图属性预测
最后就是根据定义的数据集和构建的模型进行模型训练和预测了(这里预测的是每个分子图的homolumogap,属于图级别的回归任务,而不是节点预测):
import os
import torch
import argparse
from tqdm import tqdm
from ogb.lsc import PCQM4MEvaluator
from torch_geometric.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pcqm4m_data import MyPCQM4MDataset
from gin_graph import GINGraphPooling
from torch.utils.tensorboard import SummaryWriter
def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv (list[str] | None): argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``. Pass an explicit
            list, e.g. ``[]``, when running inside a notebook, where
            ``sys.argv`` carries the kernel's own arguments.

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description='Graph data mining with GNN')
    parser.add_argument('--task_name', type=str, default='GINGraphPooling',
                        help='task name')
    parser.add_argument('--device', type=int, default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--num_layers', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
    # Restricting the choices rejects an unsupported pooling name at parse
    # time instead of later, when the model is constructed.
    parser.add_argument('--graph_pooling', type=str, default='sum',
                        choices=['sum', 'mean', 'max', 'attention', 'set2set'],
                        help='graph pooling strategy: sum, mean, max, attention or set2set (default: sum)')
    parser.add_argument('--emb_dim', type=int, default=256,
                        help='dimensionality of hidden units in GNNs (default: 256)')
    parser.add_argument('--drop_ratio', type=float, default=0.,
                        help='dropout ratio (default: 0.)')
    parser.add_argument('--save_test', action='store_true')
    parser.add_argument('--batch_size', type=int, default=512,
                        help='input batch size for training (default: 512)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--weight_decay', type=float, default=0.00001,
                        help='weight decay')
    parser.add_argument('--early_stop', type=int, default=10,
                        help='early stop (default: 10)')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of workers (default: 4)')
    parser.add_argument('--dataset_root', type=str, default="dataset",
                        help='dataset root')
    return parser.parse_args(argv)
def prepartion(args):
    """Set up the run directory, device and log file, storing them on ``args``.

    Sets ``args.save_dir``, replaces the integer ``args.device`` with a
    ``torch.device`` and opens ``args.output_file`` in append mode. The caller
    is responsible for closing ``args.output_file``.
    """
    base_dir = os.path.join('saves', args.task_name)
    save_dir = base_dir
    if os.path.exists(base_dir):
        # Use the first free "<base>=<idx>" name; if all 1000 candidates are
        # taken, fall back to the existing base directory.
        numbered = (base_dir + '=' + str(idx) for idx in range(1000))
        save_dir = next((path for path in numbered if not os.path.exists(path)), base_dir)

    args.save_dir = save_dir
    os.makedirs(args.save_dir, exist_ok=True)

    if torch.cuda.is_available():
        args.device = torch.device("cuda:" + str(args.device))
    else:
        args.device = torch.device("cpu")

    log_path = os.path.join(args.save_dir, 'output')
    args.output_file = open(log_path, 'a')
    print(args, file=args.output_file, flush=True)
def train(model, device, loader, optimizer, criterion_fn):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module called on each batch; its output is flattened to 1-D.
        device: device every batch is moved to.
        loader: iterable of batches exposing ``.to(device)`` and ``.y``.
        optimizer: optimizer stepped once per batch.
        criterion_fn: loss function called as ``criterion_fn(pred, target)``.

    Returns:
        float: sum of the batch losses divided by the number of batches.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    loss_accum = 0.0
    num_batches = 0

    for batch in tqdm(loader):
        batch = batch.to(device)
        pred = model(batch).view(-1,)
        optimizer.zero_grad()
        # Give the target the prediction's shape (as eval() does) so the loss
        # never broadcasts a (B, 1) target against a (B,) prediction.
        loss = criterion_fn(pred, batch.y.view(pred.shape))
        loss.backward()
        optimizer.step()
        loss_accum += loss.item()
        num_batches += 1

    if num_batches == 0:
        # The average over zero batches is undefined; fail with a clear message.
        raise ValueError("Training loader yielded no batches.")
    return loss_accum / num_batches
def eval(model, device, loader, evaluator):
    """Run the model over ``loader`` without gradients and return the evaluator's "mae".

    Predictions and targets of all batches are collected on the CPU,
    concatenated and handed to ``evaluator.eval`` as ``y_pred`` / ``y_true``.
    """
    model.eval()
    targets, predictions = [], []

    with torch.no_grad():
        for batch in tqdm(loader):
            batch = batch.to(device)
            output = model(batch).view(-1,)
            targets.append(batch.y.view(output.shape).detach().cpu())
            predictions.append(output.detach().cpu())

    scores = evaluator.eval({
        "y_true": torch.cat(targets, dim=0),
        "y_pred": torch.cat(predictions, dim=0),
    })
    return scores["mae"]
def test(model, device, loader):
    """Return the model's flattened predictions for every batch in ``loader``.

    Runs in eval mode without gradients; the result is a single 1-D CPU tensor.
    """
    model.eval()

    with torch.no_grad():
        outputs = [
            model(batch.to(device)).view(-1,).detach().cpu()
            for batch in loader
        ]

    return torch.cat(outputs, dim=0)
def main(args):
    """Train GINGraphPooling on the PCQM4M dataset with early stopping.

    Each epoch trains on the train split and evaluates on the validation
    split. When validation MAE improves and ``args.save_test`` is set, a
    checkpoint and a test submission file are written to ``args.save_dir``.
    Training stops after ``args.early_stop`` consecutive epochs without
    improvement.

    Args:
        args: namespace produced by ``parse_args``; ``prepartion`` adds
            ``save_dir``, ``device`` and ``output_file`` to it.
    """
    prepartion(args)
    writer = None
    try:
        nn_params = {
            'num_layers': args.num_layers,
            'emb_dim': args.emb_dim,
            'drop_ratio': args.drop_ratio,
            'graph_pooling': args.graph_pooling
        }

        # automatic dataloading and splitting
        dataset = MyPCQM4MDataset(root=args.dataset_root)
        split_idx = dataset.get_idx_split()

        train_data = dataset[split_idx['train']]
        valid_data = dataset[split_idx['valid']]
        test_data = dataset[split_idx['test']]

        train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
        valid_loader = DataLoader(valid_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
        test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

        # Evaluator used for the validation MAE and the test submission file.
        evaluator = PCQM4MEvaluator()
        criterion_fn = torch.nn.MSELoss()

        device = args.device

        model = GINGraphPooling(**nn_params).to(device)
        num_params = sum(p.numel() for p in model.parameters())
        print(f'#Params: {num_params}', file=args.output_file, flush=True)
        print(model, file=args.output_file, flush=True)

        optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=args.weight_decay)
        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)

        writer = SummaryWriter(log_dir=args.save_dir)

        not_improved = 0
        best_valid_mae = 9999

        for epoch in range(1, args.epochs + 1):
            print("=====Epoch {}".format(epoch), file=args.output_file, flush=True)
            print('Training...', file=args.output_file, flush=True)
            # train() returns the mean of criterion_fn, i.e. an MSE value, so
            # it is tracked as a loss rather than as an MAE.
            train_loss = train(model, device, train_loader, optimizer, criterion_fn)

            print('Evaluating...', file=args.output_file, flush=True)
            valid_mae = eval(model, device, valid_loader, evaluator)

            print({'Train': train_loss, 'Validation': valid_mae}, file=args.output_file, flush=True)

            writer.add_scalar('valid/mae', valid_mae, epoch)
            writer.add_scalar('train/loss', train_loss, epoch)

            if valid_mae < best_valid_mae:
                best_valid_mae = valid_mae
                if args.save_test:
                    print('Saving checkpoint...', file=args.output_file, flush=True)
                    checkpoint = {
                        'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(), 'best_val_mae': best_valid_mae, 'num_params': num_params
                    }
                    torch.save(checkpoint, os.path.join(args.save_dir, 'checkpoint.pt'))
                    print('Predicting on test data...', file=args.output_file, flush=True)
                    y_pred = test(model, device, test_loader)
                    print('Saving test submission file...', file=args.output_file, flush=True)
                    evaluator.save_test_submission({'y_pred': y_pred}, args.save_dir)

                not_improved = 0
            else:
                not_improved += 1
                if not_improved == args.early_stop:
                    print(f"Have not improved for {not_improved} epoches.", file=args.output_file, flush=True)
                    break

            scheduler.step()
            print(f'Best validation MAE so far: {best_valid_mae}', file=args.output_file, flush=True)
    finally:
        # Release the TensorBoard writer and the log file even when data
        # loading, training or evaluation raises.
        if writer is not None:
            writer.close()
        args.output_file.close()
if __name__ == "__main__":
    # Script entry point: parse the command line, then run training.
    main(parse_args())
usage: ipykernel_launcher.py [-h] [--task_name TASK_NAME] [--device DEVICE]
[--num_layers NUM_LAYERS]
[--graph_pooling GRAPH_POOLING]
[--emb_dim EMB_DIM] [--drop_ratio DROP_RATIO]
[--save_test] [--batch_size BATCH_SIZE]
[--epochs EPOCHS] [--weight_decay WEIGHT_DECAY]
[--early_stop EARLY_STOP]
[--num_workers NUM_WORKERS]
[--dataset_root DATASET_ROOT]