trainer.py

import torch
from tqdm import tqdm
import torch.optim as optim
from utils.dataset import GraphData


class Trainer:
    def __init__(self, args, net, G_data):
        self.args = args
        self.net = net
        self.feat_dim = G_data.feat_dim
        self.fold_idx = G_data.fold_idx
        self.init(args, G_data.train_gs, G_data.test_gs)
        if torch.cuda.is_available():
            self.net.cuda()

    def init(self, args, train_gs, test_gs):
        print('#train: %d, #test: %d' % (len(train_gs), len(test_gs)))
        train_data = GraphData(train_gs, self.feat_dim)
        test_data = GraphData(test_gs, self.feat_dim)
        self.train_d = train_data.loader(self.args.batch, True)
        self.test_d = test_data.loader(self.args.batch, False)
        self.optimizer = optim.Adam(
            self.net.parameters(), lr=self.args.lr, amsgrad=True,
            weight_decay=0.0008)

    def to_cuda(self, gs):
        if torch.cuda.is_available():
            if type(gs) == list:
                return [g.cuda() for g in gs]
            return gs.cuda()
        return gs

    def run_epoch(self, epoch, data, model, optimizer):
        losses, accs, n_samples = [], [], 0
        for batch in tqdm(data, desc=str(epoch), unit='b'):
            cur_len, gs, hs, ys = batch
            gs, hs, ys = map(self.to_cuda, [gs, hs, ys])
            loss, acc = model(gs, hs, ys)
            losses.append(loss*cur_len)
            accs.append(acc*cur_len)
            n_samples += cur_len
            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        avg_loss, avg_acc = sum(losses) / n_samples, sum(accs) / n_samples
        return avg_loss.item(), avg_acc.item()

    def train(self):
        max_acc = 0.0
        train_str = 'Train epoch %d: loss %.5f acc %.5f'
        test_str = 'Test epoch %d: loss %.5f acc %.5f max %.5f'
        line_str = '%d:\t%.5f\n'
        for e_id in range(self.args.num_epochs):
            self.net.train()
            loss, acc = self.run_epoch(
                e_id, self.train_d, self.net, self.optimizer)
            print(train_str % (e_id, loss, acc))
            with torch.no_grad():
                self.net.eval()
                loss, acc = self.run_epoch(e_id, self.test_d, self.net, None)
            max_acc = max(max_acc, acc)
            print(test_str % (e_id, loss, acc, max_acc))
        with open(self.args.acc_file, 'a+') as f:
            f.write(line_str % (self.fold_idx, max_acc))

network.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.ops import GCN, GraphUnet, Initializer, norm_g


class GNet(nn.Module):
    def __init__(self, in_dim, n_classes, args):
        super(GNet, self).__init__()
        self.n_act = getattr(nn, args.act_n)()
        self.c_act = getattr(nn, args.act_c)()
        self.s_gcn = GCN(in_dim, args.l_dim, self.n_act, args.drop_n)
        self.g_unet = GraphUnet(
            args.ks, args.l_dim, args.l_dim, args.l_dim, self.n_act,
            args.drop_n)
        # the readout concatenates max/sum/mean over len(ks)+1 feature maps,
        # so this layer assumes args.l_num == len(args.ks)
        self.out_l_1 = nn.Linear(3*args.l_dim*(args.l_num+1), args.h_dim)
        self.out_l_2 = nn.Linear(args.h_dim, n_classes)
        self.out_drop = nn.Dropout(p=args.drop_c)
        Initializer.weights_init(self)

    def forward(self, gs, hs, labels):
        hs = self.embed(gs, hs)
        logits = self.classify(hs)
        return self.metric(logits, labels)

    def embed(self, gs, hs):
        o_hs = []
        for g, h in zip(gs, hs):
            h = self.embed_one(g, h)
            o_hs.append(h)
        hs = torch.stack(o_hs, 0)
        return hs

    def embed_one(self, g, h):
        g = norm_g(g)
        h = self.s_gcn(g, h)
        hs = self.g_unet(g, h)
        h = self.readout(hs)
        return h

    def readout(self, hs):
        h_max = [torch.max(h, 0)[0] for h in hs]
        h_sum = [torch.sum(h, 0) for h in hs]
        h_mean = [torch.mean(h, 0) for h in hs]
        h = torch.cat(h_max + h_sum + h_mean)
        return h

    def classify(self, h):
        h = self.out_drop(h)
        h = self.out_l_1(h)
        h = self.c_act(h)
        h = self.out_drop(h)
        h = self.out_l_2(h)
        return F.log_softmax(h, dim=1)

    def metric(self, logits, labels):
        loss = F.nll_loss(logits, labels)
        _, preds = torch.max(logits, 1)
        acc = torch.mean((preds == labels).float())
        return loss, acc

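A minimal standalone sketch (not part of the original sources) of how GNet is wired together. The hyperparameter values mirror the defaults in main.py; the 10-node ring graph and random features are invented for illustration, and `from network import GNet` assumes the repo layout used by main.py. With these defaults the readout width is 3*l_dim*(l_num+1) = 576, which is what out_l_1 expects.

import torch
from argparse import Namespace
from network import GNet

args = Namespace(act_n='ELU', act_c='ELU', l_dim=48, h_dim=512, l_num=3,
                 drop_n=0.3, drop_c=0.2, ks=[0.9, 0.8, 0.7])
net = GNet(in_dim=7, n_classes=2, args=args)

n = 10
g = torch.eye(n)                                        # self-loops
g[torch.arange(n), (torch.arange(n) + 1) % n] = 1.0     # ring edges
g = ((g + g.t()) > 0).float()                           # symmetric adjacency
h = torch.rand(n, 7)                                    # toy node features (feat_dim=7)
y = torch.tensor([1])                                   # one label per graph

loss, acc = net([g], [h], y)                            # batch of one graph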
main.py

import argparse
import random
import time
import torch
import numpy as np
from network import GNet
from trainer import Trainer
from utils.data_loader import FileLoader


def get_args():
    parser = argparse.ArgumentParser(description='Args for graph prediction')
    parser.add_argument('-seed', type=int, default=1, help='seed')
    parser.add_argument('-data', default='DD', help='data folder name')
    parser.add_argument('-fold', type=int, default=1,
                        help='fold index (1..10); 0 runs all 10 folds')
    parser.add_argument('-num_epochs', type=int, default=2, help='epochs')
    parser.add_argument('-batch', type=int, default=8, help='batch size')
    parser.add_argument('-lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('-deg_as_tag', type=int, default=0,
                        help='use node degree as tag (1) or original tags (0)')
    parser.add_argument('-l_num', type=int, default=3, help='layer num')
    parser.add_argument('-h_dim', type=int, default=512, help='hidden dim')
    parser.add_argument('-l_dim', type=int, default=48, help='layer dim')
    parser.add_argument('-drop_n', type=float, default=0.3, help='drop net')
    parser.add_argument('-drop_c', type=float, default=0.2, help='drop output')
    parser.add_argument('-act_n', type=str, default='ELU', help='network act')
    parser.add_argument('-act_c', type=str, default='ELU', help='output act')
    parser.add_argument('-ks', nargs='+', type=float,
                        default=[0.9, 0.8, 0.7], help='pooling ratios')
    parser.add_argument('-acc_file', type=str, default='re', help='acc file')
    args, _ = parser.parse_known_args()
    return args


def set_random(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def app_run(args, G_data, fold_idx):
    G_data.use_fold_data(fold_idx)
    net = GNet(G_data.feat_dim, G_data.num_class, args)
    trainer = Trainer(args, net, G_data)
    trainer.train()


def main():
    args = get_args()
    print(args)
    set_random(args.seed)
    start = time.time()
    G_data = FileLoader(args).load_data()
    print('load data using ------>', time.time()-start)
    if args.fold == 0:
        for fold_idx in range(10):
            print('start training ------> fold', fold_idx+1)
            app_run(args, G_data, fold_idx)
    else:
        print('start training ------> fold', args.fold)
        app_run(args, G_data, args.fold-1)


if __name__ == "__main__":
    main()

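For reference, with a dataset file in place at data/&lt;name&gt;/&lt;name&gt;.txt, a single fold is trained with, for example, `python main.py -data DD -fold 1` (the epoch count defaults to 2), and all ten folds with `-fold 0`; the best test accuracy per fold is appended to the file named by -acc_file.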
utils/data_loader.py

import torch
from tqdm import tqdm
import networkx as nx
import numpy as np
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold
from functools import partial


class G_data(object):
    def __init__(self, num_class, feat_dim, g_list):
        self.num_class = num_class
        self.feat_dim = feat_dim
        self.g_list = g_list
        self.sep_data()

    def sep_data(self, seed=0):
        skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
        labels = [g.label for g in self.g_list]
        self.idx_list = list(skf.split(np.zeros(len(labels)), labels))

    def use_fold_data(self, fold_idx):
        self.fold_idx = fold_idx + 1
        train_idx, test_idx = self.idx_list[fold_idx]
        self.train_gs = [self.g_list[i] for i in train_idx]
        self.test_gs = [self.g_list[i] for i in test_idx]


class FileLoader(object):
    def __init__(self, args):
        self.args = args

    def line_genor(self, lines):
        for line in lines:
            yield line

    def gen_graph(self, f, i, label_dict, feat_dict, deg_as_tag):
        row = next(f).strip().split()
        n, label = [int(w) for w in row]
        if label not in label_dict:
            label_dict[label] = len(label_dict)
        g = nx.Graph()
        g.add_nodes_from(list(range(n)))
        node_tags = []
        for j in range(n):
            row = next(f).strip().split()
            tmp = int(row[1]) + 2
            row = [int(w) for w in row[:tmp]]
            if row[0] not in feat_dict:
                feat_dict[row[0]] = len(feat_dict)
            for k in range(2, len(row)):
                if j != row[k]:
                    g.add_edge(j, row[k])
            if len(row) > 2:
                node_tags.append(feat_dict[row[0]])
        g.label = label
        g.remove_nodes_from(list(nx.isolates(g)))
        if deg_as_tag:
            g.node_tags = list(dict(g.degree).values())
        else:
            g.node_tags = node_tags
        return g

    def process_g(self, label_dict, tag2index, tagset, g):
        g.label = label_dict[g.label]
        g.feas = torch.tensor([tag2index[tag] for tag in g.node_tags])
        g.feas = F.one_hot(g.feas, len(tagset))
        # nx.to_numpy_array replaces nx.to_numpy_matrix, which was removed in networkx 3.0
        A = torch.FloatTensor(nx.to_numpy_array(g))
        g.A = A + torch.eye(g.number_of_nodes())  # add self-loops
        return g

    def load_data(self):
        args = self.args
        print('loading data ...')
        g_list = []
        label_dict = {}
        feat_dict = {}
        with open('data/%s/%s.txt' % (args.data, args.data), 'r') as f:
            lines = f.readlines()
        f = self.line_genor(lines)
        n_g = int(next(f).strip())
        for i in tqdm(range(n_g), desc="Create graph", unit='graphs'):
            g = self.gen_graph(f, i, label_dict, feat_dict, args.deg_as_tag)
            g_list.append(g)
        tagset = set([])
        for g in g_list:
            tagset = tagset.union(set(g.node_tags))
        tagset = list(tagset)
        tag2index = {tagset[i]: i for i in range(len(tagset))}
        f_n = partial(self.process_g, label_dict, tag2index, tagset)
        new_g_list = []
        for g in tqdm(g_list, desc="Process graph", unit='graphs'):
            new_g_list.append(f_n(g))
        num_class = len(label_dict)
        feat_dim = len(tagset)
        print('# classes: %d' % num_class, '# node tags: %d' % feat_dim)
        return G_data(num_class, feat_dim, new_g_list)

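FileLoader reads data/&lt;name&gt;/&lt;name&gt;.txt as plain text: the first line is the number of graphs, and each graph starts with a line "num_nodes label" followed by one line per node of the form "tag num_neighbors neighbor_ids...". A small sketch (the 'TOY' dataset name and its contents are invented here purely for illustration) that writes such a file and loads it; note that G_data's 10-fold split requires at least 10 graphs per class:

import os
from argparse import Namespace
from utils.data_loader import FileLoader

lines = ['20']                       # 20 toy graphs, 10 per class
for g_i in range(20):
    label = g_i % 2
    lines += ['3 %d' % label,        # graph header: 3 nodes, class label
              '0 2 1 2',             # node 0: tag 0, neighbors 1 and 2
              '%d 1 0' % label,      # node 1: tag 0 or 1, neighbor 0
              '1 1 0']               # node 2: tag 1, neighbor 0
os.makedirs('data/TOY', exist_ok=True)
with open('data/TOY/TOY.txt', 'w') as f:
    f.write('\n'.join(lines) + '\n')

G_data = FileLoader(Namespace(data='TOY', deg_as_tag=0)).load_data()
print(G_data.num_class, G_data.feat_dim)            # -> 2 classes, 2 node tags
G_data.use_fold_data(0)
print(len(G_data.train_gs), len(G_data.test_gs))    # -> 18 2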
utils/dataset.py

import random
import torch


class GraphData(object):
    def __init__(self, data, feat_dim):
        super(GraphData, self).__init__()
        self.data = data
        self.feat_dim = feat_dim
        self.idx = list(range(len(data)))
        self.pos = 0

    def __reset__(self):
        self.pos = 0
        if self.shuffle:
            random.shuffle(self.idx)

    def __len__(self):
        # number of batches per epoch (ceiling division)
        return (len(self.data) + self.batch - 1) // self.batch

    def __getitem__(self, idx):
        g = self.data[idx]
        return g.A, g.feas.float(), g.label

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= len(self.data):
            self.__reset__()
            raise StopIteration
        cur_idx = self.idx[self.pos: self.pos + self.batch]
        data = [self.__getitem__(idx) for idx in cur_idx]
        self.pos += len(cur_idx)
        gs, hs, labels = map(list, zip(*data))
        return len(gs), gs, hs, torch.LongTensor(labels)

    def loader(self, batch, shuffle, *args):
        self.batch = batch
        self.shuffle = shuffle
        if shuffle:
            random.shuffle(self.idx)
        return self

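GraphData acts as both dataset and iterator: loader() only records the batch size and shuffle flag and returns self, and every batch is the 4-tuple (batch length, list of adjacency matrices, list of feature matrices, label tensor) that Trainer.run_epoch unpacks. An illustrative sketch, with toy graphs invented here and assuming the utils/dataset.py import path used elsewhere:

import torch
from utils.dataset import GraphData

class ToyGraph:
    """Stands in for the processed graphs produced by FileLoader."""
    def __init__(self, n, feat_dim, label):
        self.A = torch.eye(n)                   # adjacency with self-loops only
        self.feas = torch.eye(n)[:, :feat_dim]  # fake one-hot node features
        self.label = label

graphs = [ToyGraph(4, 3, 0), ToyGraph(5, 3, 1), ToyGraph(6, 3, 0)]
loader = GraphData(graphs, feat_dim=3).loader(batch=2, shuffle=False)
for cur_len, gs, hs, ys in loader:
    print(cur_len, [g.shape for g in gs], [h.shape for h in hs], ys)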
utils/ops.py

import torch
import torch.nn as nn
import numpy as np


class GraphUnet(nn.Module):
    def __init__(self, ks, in_dim, out_dim, dim, act, drop_p):
        super(GraphUnet, self).__init__()
        self.ks = ks
        self.bottom_gcn = GCN(dim, dim, act, drop_p)
        self.down_gcns = nn.ModuleList()
        self.up_gcns = nn.ModuleList()
        self.pools = nn.ModuleList()
        self.unpools = nn.ModuleList()
        self.l_n = len(ks)
        for i in range(self.l_n):
            self.down_gcns.append(GCN(dim, dim, act, drop_p))
            self.up_gcns.append(GCN(dim, dim, act, drop_p))
            self.pools.append(Pool(ks[i], dim, drop_p))
            self.unpools.append(Unpool(dim, dim, drop_p))

    def forward(self, g, h):
        adj_ms = []
        indices_list = []
        down_outs = []
        hs = []
        org_h = h
        # encoder: GCN followed by pooling at each level
        for i in range(self.l_n):
            h = self.down_gcns[i](g, h)
            adj_ms.append(g)
            down_outs.append(h)
            g, h, idx = self.pools[i](g, h)
            indices_list.append(idx)
        h = self.bottom_gcn(g, h)
        # decoder: unpooling and GCN with skip connections to the encoder
        for i in range(self.l_n):
            up_idx = self.l_n - i - 1
            g, idx = adj_ms[up_idx], indices_list[up_idx]
            g, h = self.unpools[i](g, h, down_outs[up_idx], idx)
            h = self.up_gcns[i](g, h)
            h = h.add(down_outs[up_idx])
            hs.append(h)
        h = h.add(org_h)
        hs.append(h)
        return hs


class GCN(nn.Module):
    def __init__(self, in_dim, out_dim, act, p):
        super(GCN, self).__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        self.act = act
        self.drop = nn.Dropout(p=p) if p > 0.0 else nn.Identity()

    def forward(self, g, h):
        h = self.drop(h)
        h = torch.matmul(g, h)
        h = self.proj(h)
        h = self.act(h)
        return h


class Pool(nn.Module):
    def __init__(self, k, in_dim, p):
        super(Pool, self).__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()

    def forward(self, g, h):
        Z = self.drop(h)
        weights = self.proj(Z).squeeze()
        scores = self.sigmoid(weights)
        return top_k_graph(scores, g, h, self.k)


class Unpool(nn.Module):
    def __init__(self, *args):
        super(Unpool, self).__init__()

    def forward(self, g, h, pre_h, idx):
        # scatter pooled node features back to their original positions
        new_h = h.new_zeros([g.shape[0], h.shape[1]])
        new_h[idx] = h
        return g, new_h


def top_k_graph(scores, g, h, k):
    # keep the top max(2, int(k*N)) scoring nodes, gate their features by the
    # scores, and connect them through the 2-hop adjacency (A @ A)
    num_nodes = g.shape[0]
    values, idx = torch.topk(scores, max(2, int(k*num_nodes)))
    new_h = h[idx, :]
    values = torch.unsqueeze(values, -1)
    new_h = torch.mul(new_h, values)
    un_g = g.bool().float()
    un_g = torch.matmul(un_g, un_g).bool().float()
    un_g = un_g[idx, :]
    un_g = un_g[:, idx]
    g = norm_g(un_g)
    return g, new_h, idx


def norm_g(g):
    # degree-normalize the adjacency matrix
    degrees = torch.sum(g, 1)
    g = g / degrees
    return g


class Initializer(object):
    @classmethod
    def _glorot_uniform(cls, w):
        if len(w.size()) == 2:
            fan_in, fan_out = w.size()
        elif len(w.size()) == 3:
            fan_in = w.size()[1] * w.size()[2]
            fan_out = w.size()[0] * w.size()[2]
        else:
            fan_in = np.prod(w.size())
            fan_out = np.prod(w.size())
        limit = np.sqrt(6.0 / (fan_in + fan_out))
        w.uniform_(-limit, limit)

    @classmethod
    def _param_init(cls, m):
        if isinstance(m, nn.parameter.Parameter):
            cls._glorot_uniform(m.data)
        elif isinstance(m, nn.Linear):
            m.bias.data.zero_()
            cls._glorot_uniform(m.weight.data)

    @classmethod
    def weights_init(cls, m):
        for p in m.modules():
            if isinstance(p, nn.ParameterList):
                for pp in p:
                    cls._param_init(pp)
            else:
                cls._param_init(p)
        for name, p in m.named_parameters():
            if '.' not in name:  # top-level parameters not covered above
                cls._param_init(p)

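A small worked example (values invented, and assuming ops.py sits under utils/ as the other imports suggest) of the pooling helper: top_k_graph keeps the max(2, int(k*N)) highest-scoring nodes, gates their features by the scores, restricts the 2-hop adjacency A @ A to the kept nodes, and degree-normalizes the result with norm_g.

import torch
from utils.ops import top_k_graph

# 4-node path graph 0-1-2-3 with self-loops
A = torch.tensor([[1., 1., 0., 0.],
                  [1., 1., 1., 0.],
                  [0., 1., 1., 1.],
                  [0., 0., 1., 1.]])
h = torch.arange(8.).view(4, 2)              # toy node features
scores = torch.tensor([0.9, 0.2, 0.8, 0.4])  # pretend projection scores

g, new_h, idx = top_k_graph(scores, A, h, k=0.5)
print(idx)    # tensor([0, 2]): the top max(2, int(0.5*4)) = 2 nodes
print(new_h)  # features of nodes 0 and 2, scaled by their scores
print(g)      # normalized 2-hop adjacency among kept nodes (all 0.5 here)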
transform.py

import networkx as nx
from collections import defaultdict


def get_indic(ind_file):
    with open(ind_file) as f:
        lines = f.readlines()
    indic = {}
    g_sizes = [0]*int(lines[-1])
    for i, line in enumerate(lines):
        g_id = int(line)-1
        indic[i] = g_id
        g_sizes[g_id] += 1
    return indic, g_sizes


def get_labels(label_file):
    with open(label_file) as f:
        lines = f.readlines()
    return {i: int(line)-1 for i, line in enumerate(lines)}


def trans_graphs(g_file, A_file, labels, id_dict, g_sizes):
    with open(g_file, 'w') as f:
        f.write('%s\n' % len(labels))
    with open(A_file) as f:
        lines = f.readlines()
    edges_list = [defaultdict(list) for _ in range(len(labels))]
    for line in lines:
        i, j = list(map(int, line.split(',')))
        g_id = id_dict[i-1]
        edges_list[g_id][i-1].append(j-1)
        edges_list[g_id][j-1].append(i-1)  # use set to remove duplicate
    acc_g_size = 0
    for i in range(len(labels)):
        print('=========> graph', i, acc_g_size)
        label = labels[i]
        edges = edges_list[i]
        write_graph(g_file, acc_g_size, label, edges)
        acc_g_size += g_sizes[i]


def write_graph(g_file, acc_g_size, label, edges):
    with open(g_file, 'a+') as f:
        f.write('%s %s\n' % (len(edges), label))
        num_nodes = len(edges)
        for i in range(len(edges)):
            neighbors = sorted(set(edges[acc_g_size+i]))
            neighbors = [n_id - acc_g_size for n_id in neighbors]
            neighbors = [n_id for n_id in neighbors if n_id < num_nodes]
            f.write('0 %s %s\n' % (
                len(neighbors), ' '.join(map(str, neighbors))))


def load_data(g_file):
    print('loading data')
    label_dict = {}
    feat_dict = {}
    acc_line_num = 0
    with open(g_file, 'r') as f:
        n_g = int(f.readline().strip())
        acc_line_num += 1
        for i in range(n_g):
            row = f.readline().strip().split()
            acc_line_num += 1
            num_node, label = [int(w) for w in row]
            if label not in label_dict:
                mapped = len(label_dict)
                label_dict[label] = mapped
            g = nx.Graph()
            node_tags = []
            n_edges = 0
            for j in range(num_node):
                g.add_node(j)
                row = f.readline().strip().split()
                acc_line_num += 1
                tmp = int(row[1]) + 2
                if tmp == len(row):
                    row = [int(w) for w in row]
                else:
                    row = [int(w) for w in row[:tmp]]
                if not row[0] in feat_dict:
                    mapped = len(feat_dict)
                    feat_dict[row[0]] = mapped
                node_tags.append(feat_dict[row[0]])
                n_edges += row[1]
                for k in range(2, len(row)):
                    g.add_edge(j, row[k])
            assert len(g) == num_node


if __name__ == '__main__':
    name = 'REDDIT-MULTI-12K'
    g_file = '%s.txt' % name
    id_dict, g_sizes = get_indic('%s/%s_graph_indicator.txt' % (name, name))
    labels = get_labels('%s/%s_graph_labels.txt' % (name, name))
    trans_graphs(
        g_file, '%s/%s_A.txt' % (name, name), labels,
        id_dict, g_sizes)
    load_data(g_file)
    # print(id_dict)
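The __main__ block above converts the three TU-format files (&lt;name&gt;_graph_indicator.txt, &lt;name&gt;_graph_labels.txt, &lt;name&gt;_A.txt) into the single text file that load_data and FileLoader expect. A toy end-to-end run, with the 'TINY' dataset name and file contents invented here to show the expected inputs (nodes and labels are 1-indexed, TU style):

import os
from transform import get_indic, get_labels, trans_graphs, load_data

name = 'TINY'
os.makedirs(name, exist_ok=True)
# two graphs: nodes 1-3 belong to graph 1, nodes 4-5 to graph 2
with open('%s/%s_graph_indicator.txt' % (name, name), 'w') as f:
    f.write('1\n1\n1\n2\n2\n')
with open('%s/%s_graph_labels.txt' % (name, name), 'w') as f:
    f.write('1\n2\n')
with open('%s/%s_A.txt' % (name, name), 'w') as f:
    f.write('1, 2\n2, 1\n2, 3\n3, 2\n4, 5\n5, 4\n')

id_dict, g_sizes = get_indic('%s/%s_graph_indicator.txt' % (name, name))
labels = get_labels('%s/%s_graph_labels.txt' % (name, name))
trans_graphs('%s.txt' % name, '%s/%s_A.txt' % (name, name), labels,
             id_dict, g_sizes)
load_data('%s.txt' % name)   # sanity check: re-parses the written file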