SLVAE

#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn.functional as F
import torch.nn as nn
import argparse
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD
import pickle
import warnings
warnings.filterwarnings("ignore",category=DeprecationWarning)
warnings.filterwarnings("ignore",category=UserWarning)
print('Is GPU available? {}\n'.format(torch.cuda.is_available()))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

parser = argparse.ArgumentParser(description="SLVAE")
datasets = ['jazz_SIR', 'jazz_SI', 'cora_ml_SIR', 'cora_ml_SI', 'power_grid_SIR', 'power_grid_SI',
            'karate_SIR', 'karate_SI', 'netscience_SIR', 'netscience_SI']
parser.add_argument("-d", "--dataset", default="karate_SI", type=str,
                    help="one of: {}".format(", ".join(sorted(datasets))))

args = parser.parse_args(args=[])

import scipy.sparse as sp
from typing import List

import numpy as np

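# InverseModel: couples the VAE (which reconstructs the seed vector) with the GNN forward model and the
# diffusion propagation layer, so reconstruction and forward-prediction losses can be trained jointly.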
class InverseModel(nn.Module):
    def __init__(self, vae_model: nn.Module, gnn_model: nn.Module, propagate: nn.Module):
        super(InverseModel, self).__init__()

        self.vae_model = vae_model
        self.gnn_model = gnn_model
        self.propagate = propagate

        self.reg_params = list(filter(lambda x: x.requires_grad, self.gnn_model.parameters()))

    def forward(self, input_pair, seed_vec):
        device = next(self.gnn_model.parameters()).device
        seed_idx = torch.LongTensor(np.argwhere(seed_vec.cpu().detach().numpy() == 1)).to(device)

        seed_hat, mean, log_var = self.vae_model(input_pair)
        predictions = self.gnn_model(seed_hat)
        predictions = self.propagate(predictions, seed_idx)

        return seed_hat, mean, log_var, predictions

    def loss(self, x, x_hat, mean, log_var, y, y_hat):
        forward_loss = F.mse_loss(y_hat, y)
        reproduction_loss = F.binary_cross_entropy(x_hat, x, reduction='mean')
        # reproduction_loss = F.mse_loss(x_hat, x)
        KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        total_loss = forward_loss + reproduction_loss + 1e-3 * KLD
        return KLD, reproduction_loss, forward_loss, total_loss


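# ForwardModel: wraps only the trained GNN and propagation layers; given a (relaxed) seed vector it
# predicts the resulting diffusion. Used at inference time when optimizing over the seed vector.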
class ForwardModel(nn.Module):
    def __init__(self, gnn_model: nn.Module, propagate: nn.Module):
        super(ForwardModel, self).__init__()
        self.gnn_model = gnn_model
        self.propagate = propagate
        self.relu = nn.ReLU(inplace=True)

        self.reg_params = list(filter(lambda x: x.requires_grad, self.gnn_model.parameters()))

    def forward(self, seed_vec):
        device = next(self.gnn_model.parameters()).device
        # seed_idx = torch.LongTensor(np.argwhere(seed_vec.cpu().detach().numpy() == 1)).to(device)
        seed_idx = (seed_vec == 1).nonzero(as_tuple=False)

        predictions = self.gnn_model(seed_vec)
        predictions = self.propagate(predictions, seed_idx)

        # predictions = (predictions + seed_vec)/2

        predictions = self.relu(predictions)

        return predictions

    def loss(self, y, y_hat):
        forward_loss = F.mse_loss(y_hat, y)
        return forward_loss
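

# GNNModel: builds per-node features from successive powers of the propagation probability matrix applied
# to the seed vector, then maps them through an MLP with a sigmoid output (per-node infection probability).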
class GNNModel(nn.Module):
    def __init__(self, input_dim, hiddenunits: List[int], num_classes, prob_matrix, bias=True, drop_prob=0.5):
        super(GNNModel, self).__init__()

        self.input_dim = input_dim

        if sp.isspmatrix(prob_matrix):
            prob_matrix = prob_matrix.toarray()
        # requires_grad=False: the propagation probability matrix is fixed and not updated during training
        self.prob_matrix = nn.Parameter(torch.FloatTensor(prob_matrix), requires_grad=False)

        fcs = [nn.Linear(input_dim, hiddenunits[0], bias=bias)]
        for i in range(1, len(hiddenunits)):
            fcs.append(nn.Linear(hiddenunits[i - 1], hiddenunits[i]))
        fcs.append(nn.Linear(hiddenunits[-1], num_classes))

        self.fcs = nn.ModuleList(fcs)

        if drop_prob == 0:
            self.dropout = lambda x: x
        else:
            self.dropout = nn.Dropout(drop_prob)

        self.act_fn = nn.ReLU()

    def forward(self, seed_vec):

        for i in range(self.input_dim - 1):
            if i == 0:
                mat = self.prob_matrix.T @ seed_vec.T
                attr_mat = torch.cat((seed_vec.T.unsqueeze(0), mat.unsqueeze(0)), 0)
            else:
                mat = self.prob_matrix.T @ attr_mat[-1]
                attr_mat = torch.cat((attr_mat, mat.unsqueeze(0)), 0)

        layer_inner = self.act_fn(self.fcs[0](self.dropout(attr_mat.T)))
        for fc in self.fcs[1:-1]:
            layer_inner = self.act_fn(fc(layer_inner))
        res = torch.sigmoid(self.fcs[-1](self.dropout(layer_inner)))
        return res

    def loss(self, y, y_hat):
        forward_loss = F.mse_loss(y_hat, y)
        return forward_loss


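# Encoder: maps an observed seed vector to the mean and log-variance of the latent Gaussian q(z|x).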
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        self.FC_input = nn.Linear(input_dim, hidden_dim)
        self.FC_input2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_mean = nn.Linear(hidden_dim, latent_dim)
        self.FC_var = nn.Linear(hidden_dim, latent_dim)

        self.bn = nn.BatchNorm1d(num_features=latent_dim)

    def forward(self, x):
        h_ = F.relu(self.FC_input(x))
        h_ = F.relu(self.FC_input2(h_))
        # mean and log-variance of the approximate posterior q(z|x)
        mean = self.FC_mean(h_)
        log_var = self.FC_var(h_)
        return mean, log_var


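# Decoder: maps a latent sample z back to a per-node seed probability vector (sigmoid output).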
class Decoder(nn.Module):
    def __init__(self, input_dim, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.FC_input = nn.Linear(input_dim, latent_dim)
        self.FC_hidden_1 = nn.Linear(latent_dim, hidden_dim)
        self.FC_hidden_2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_output = nn.Linear(hidden_dim, output_dim)

        # self.prelu = nn.PReLU()

    def forward(self, x):
        h = F.relu(self.FC_input(x))
        h = F.relu(self.FC_hidden_1(h))
        h = F.relu(self.FC_hidden_2(h))
        x_hat = torch.sigmoid(self.FC_output(h))
        return x_hat


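# VAEModel: standard encode -> reparameterize -> decode pipeline; returns the reconstruction, mean and log-variance.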
class VAEModel(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(VAEModel, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, log_var):
        std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
        epsilon = torch.randn_like(std)
        return mean + std * epsilon

    def forward(self, x, adj=None):
        if adj is not None:
            mean, log_var = self.Encoder(x, adj)
        else:
            mean, log_var = self.Encoder(x)
        z = self.reparameterization(mean, log_var)  # takes exponential function (log var -> var)
        x_hat = self.Decoder(z)

        return x_hat, mean, log_var


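# DiffusionPropagate: differentiable diffusion layer. For niter rounds, node i's activation probability is
# updated to 1 - prod_j (1 - p_ji * x_j), i.e. the probability of being reached by at least one active neighbor.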
class DiffusionPropagate(nn.Module):
    def __init__(self, prob_matrix, niter):
        super(DiffusionPropagate, self).__init__()

        self.niter = niter

        if sp.isspmatrix(prob_matrix):
            prob_matrix = prob_matrix.toarray()

        self.register_buffer('prob_matrix', torch.FloatTensor(prob_matrix))

    def forward(self, preds, seed_idx):
        # import ipdb; ipdb.set_trace()
        # prop_preds = torch.ones((preds.shape[0], preds.shape[1])).to(device)
        device = preds.device

        for i in range(preds.shape[0]):
            prop_pred = preds[i]
            for j in range(self.niter):
                P2 = self.prob_matrix.T * prop_pred.view((1, -1)).expand(self.prob_matrix.shape)
                P3 = torch.ones(self.prob_matrix.shape).to(device) - P2
                prop_pred = torch.ones((self.prob_matrix.shape[0],)).to(device) - torch.prod(P3, dim=1)
                # prop_pred[seed_idx[seed_idx[:,0] == i][:, 1]] = 1
                prop_pred = prop_pred.unsqueeze(0)
            if i == 0:
                prop_preds = prop_pred
            else:
                prop_preds = torch.cat((prop_preds, prop_pred), 0)

        return prop_preds
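

# Load the preprocessed graph: adjacency matrix, simulated (seed, diffusion) inverse pairs,
# and the edge-level propagation probability matrix.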
with open('../data/' + args.dataset + '.SG', 'rb') as f:
    graph = pickle.load(f)

adj, inverse_pairs, prob_matrix = graph['adj'].toarray(), graph['inverse_pairs'], graph['prob'].toarray()

batch_size = 1

train_set, test_set = torch.utils.data.random_split(inverse_pairs,
                                                    [len(inverse_pairs) - batch_size,
                                                     batch_size])

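# Assemble the model: a VAE over seed vectors, the forward GNN, and the diffusion propagation layer.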
encoder = Encoder(input_dim=inverse_pairs.shape[2], hidden_dim=512, latent_dim=256)
decoder = Decoder(input_dim=256, latent_dim=512, hidden_dim=256, output_dim=inverse_pairs.shape[2])
vae_model = VAEModel(Encoder=encoder, Decoder=decoder)

gnn_model = GNNModel(input_dim=5,
                     hiddenunits=[64, 64],
                     num_classes=1,
                     prob_matrix=prob_matrix)

propagate = DiffusionPropagate(prob_matrix, niter=2)

model = InverseModel(vae_model, gnn_model, propagate).to(device)


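# Stage-1 loss: reconstruction BCE + KL divergence + forward MSE + a monotonicity penalty on the prediction.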
def loss_all(x, x_hat, log_var, mean, y_hat, y):
    forward_loss = F.mse_loss(y_hat, y, reduction='sum')
    monotone_loss = torch.sum(torch.relu(y_hat - y_hat[0]))  # monotonicity constraint on information propagation
    reproduction_loss = F.binary_cross_entropy(x_hat, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    total_loss = reproduction_loss + KLD + forward_loss + monotone_loss
    return reproduction_loss, KLD, total_loss


optimizer = Adam(model.parameters(), lr=2e-3)

model = model.to(device)
model.train()
sample_number = train_set[:].shape[0] * train_set[:].shape[1]

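# Stage 1: jointly train the VAE and the forward model on the training pairs
# (x: seed/source vector, y: observed diffusion).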
for epoch in range(100):
    re_overall = 0
    kld_overall = 0
    total_overall = 0
    precision_all = 0
    recall_all = 0

    for batch_idx, data_pair in enumerate(train_set):
        # input_pair = torch.cat((data_pair[:, :, 0], data_pair[:, :, 1]), 1).to(device)
        x = data_pair[:, :, 0].float().to(device)
        y = data_pair[:, :, 1].to(device)

        optimizer.zero_grad()

        x_true = x.cpu().detach()

        x_hat, mean, log_var, y_hat = model(x, x)

        re_loss, kld, loss = loss_all(x, x_hat, log_var, mean, y_hat, y)

        x_pred = x_hat.cpu().detach()

        kld_overall += kld.item() * x_hat.size(0)
        re_overall += re_loss.item() * x_hat.size(0)
        total_overall += loss.item() * x_hat.size(0)

        for i in range(x_true.shape[0]):
            x_pred[i][x_pred[i] > 0.55] = 1
            x_pred[i][x_pred[i] != 1] = 0
            precision_all += precision_score(x_true[i].cpu().detach().numpy(), x_pred[i].cpu().detach().numpy(),
                                             zero_division=0)
            recall_all += recall_score(x_true[i].cpu().detach().numpy(), x_pred[i].cpu().detach().numpy(),
                                       zero_division=0)

        loss.backward()
        optimizer.step()

    print("Epoch: {}".format(epoch + 1),
          "\tReconstruction: {:.4f}".format(re_overall / sample_number),
          "\tKLD: {:.4f}".format(kld_overall / sample_number),
          "\tTotal: {:.4f}".format(total_overall / sample_number),
          "\tPrecision: {:.4f}".format(precision_all / sample_number),
          "\tRecall: {:.4f}".format(recall_all / sample_number),
          )

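# Stage 2: freeze the trained VAE and forward model; inference optimizes the seed estimate directly.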
vae_model = model.vae_model
forward_model = ForwardModel(model.gnn_model, model.propagate).to(device)

for param in vae_model.parameters():
    param.requires_grad = False

for param in forward_model.parameters():
    param.requires_grad = False

encoder = vae_model.Encoder
decoder = vae_model.Decoder


def loss_seed_x(x, x_hat, loss_type='mse'):
    if loss_type == 'bce':
        return F.binary_cross_entropy(x_hat, x, reduction='mean')
    else:
        return F.mse_loss(x_hat, x)


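# Inference objective: forward MSE minus the Bernoulli log-likelihood of the relaxed seed x_hat under the
# decoded samples f_z_all, batch-normalized and aggregated with a log-sum-exp over the latent samples.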
def loss_inverse(y_true, y_hat, x_hat, f_z_all, BN):
    forward_loss = F.mse_loss(y_hat, y_true)

    log_pmf = []
    for f_z in f_z_all:
        log_likelihood_sum = torch.zeros(1, dtype=torch.double, device=device)  # double, matching temp below
        for i, x_i in enumerate(x_hat[0]):
            temp = torch.pow(f_z[i], x_i) * torch.pow(1 - f_z[i], 1 - x_i).to(torch.double)
            log_likelihood_sum += torch.log(temp)
        log_pmf.append(log_likelihood_sum)

    log_pmf = torch.stack(log_pmf)
    log_pmf = BN(log_pmf.float())

    pmf_max = torch.max(log_pmf)

    pdf_sum = pmf_max + torch.logsumexp(log_pmf - pmf_max, dim=0)

    return forward_loss - pdf_sum, forward_loss


def loss_inverse_initial(y_true, y_hat, x_hat, f_z):
    # print(y_hat.shape,y_true.shape)
    forward_loss = F.mse_loss(y_hat, y_true)

    pdf_sum = 0

    for i, x_i in enumerate(x_hat[0]):
        temp = torch.pow(f_z[i], x_i) * torch.pow(1 - f_z[i], 1 - x_i).to(torch.double)
        pdf_sum += torch.log(temp)

    return forward_loss - pdf_sum, pdf_sum


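# Warm start: optimize x_hat against the forward MSE minus its Bernoulli log-likelihood under the mean
# decoded sample f_z_bar, recording the thresholded F1 at every step so the best iterate can be selected.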
def x_hat_initialization(model, x_hat, x_true, x, y_true, f_z_bar, test_id, threshold=0.45, lr=1e-3, epochs=100):
    input_optimizer = Adam([x_hat], lr=lr)

    initial_x, initial_x_f1 = [], []

    for epoch in range(epochs):
        input_optimizer.zero_grad()

        y_hat = model(x_hat)

        loss, pdf_loss = loss_inverse_initial(y_true, y_hat, x_hat, f_z_bar)

        x_pred = x_hat.clone().cpu().detach().numpy()
        # x = x_true.cpu().detach().numpy()

        x_pred[x_pred > threshold] = 1
        x_pred[x_pred != 1] = 0
        precision = precision_score(x[0], x_pred[0])
        recall = recall_score(x[0], x_pred[0])
        f1 = f1_score(x[0], x_pred[0])

        loss.backward()
        input_optimizer.step()

        with torch.no_grad():
            x_hat.clamp_(0, 1)

        initial_x.append(x_hat.clone().detach())  # snapshot this iterate so the best one (by F1) can be recovered
        initial_x_f1.append(f1)

    return initial_x, initial_x_f1


x_comparison = {}


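# Inference on the held-out pairs: sample a random x_hat, warm-start it, then refine it with the full
# inverse objective while reporting precision, recall, F1, AUC and accuracy against the true seed set.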
for test_id, test in enumerate(test_set):
    precision_all = 0
    recall_all = 0
    f1_all = 0
    auc_all = 0
    for i in range(test.shape[0]):
        train_x = torch.tensor(train_set[:][:, 0, :, :][:, :, 0]).float().to(device)
        train_y = torch.tensor(train_set[:][:, 0, :, :][:, :, 1]).float().to(device)
        x_true = torch.tensor(test[i, :, 0]).float().unsqueeze(0).to(device)
        x_true = x_true.unsqueeze(-1)
        y_true = torch.tensor(test[i, :, 1]).float().unsqueeze(0).to(device)

        # print(x_input.shape)
        with torch.no_grad():
            mean, var = encoder(train_x)
            z_all = vae_model.reparameterization(mean, var)
            # Getting \bar z from all the z's    
            z_bar = torch.mean(z_all, dim=0)

            f_z_all = decoder(z_all)
            f_z_bar = decoder(z_bar)

            x_hat = torch.sigmoid(torch.randn(f_z_all[:1].shape)).unsqueeze(-1).to(device)

            # x_hat = torch.bernoulli(x_hat)

        x_hat.requires_grad = True
        x = x_true.cpu().detach().numpy()
        # initialization
        # model.state_dict()[name][:] += torch.rand(para.size()).cuda() * noise_lambda * torch.std(para)
        print("Getting initialization")
        initial_x, initial_x_prec = x_hat_initialization(forward_model, x_hat, x_true, x, y_true,
                                                         f_z_bar, test_id, threshold=0.3,
                                                         lr=5e-2, epochs=20)

        with torch.no_grad():
            #             init_x = torch.sigmoid(initial_x[initial_x_prec.index(max(initial_x_prec))])
            init_x = initial_x[initial_x_prec.index(max(initial_x_prec))]
        # init_x = torch.bernoulli(init_x)

        init_x.requires_grad = True

        input_optimizer = Adam([init_x], lr=1e-1)
        BN = nn.BatchNorm1d(1, affine=False).to(device)

        print("Inference Starting...")
        for epoch in range(5):
            input_optimizer.zero_grad()
            y_hat = forward_model(init_x)
            loss, forward_loss = loss_inverse(y_true, y_hat, init_x, f_z_all, BN)

            x_pred = init_x.clone().cpu().detach().numpy()

            auc = roc_auc_score(x[0], x_pred[0])

            x_pred[x_pred > 0.55] = 1
            x_pred[x_pred != 1] = 0
            precision = precision_score(x[0], x_pred[0])
            recall = recall_score(x[0], x_pred[0])
            f1 = f1_score(x[0], x_pred[0])
            accuracy = accuracy_score(x[0], x_pred[0])
            loss.backward()
            # print("loss.grad:", torch.sum(init_x.grad))
            input_optimizer.step()

            with torch.no_grad():
                init_x.clamp_(0, 1)
            print("Test #{} Epoch: {:2d}".format(i + 1, epoch + 1),
                  "\tTotal Loss: {:.5f}".format(loss.item()),
                  "\tx Loss: {:.5f}".format(loss_seed_x(x_true, init_x, loss_type='bce').item()),
                  "\tPrec: {:.5f}".format(precision),
                  "\tRec: {:.5f}".format(recall),
                  "\tF1: {:.5f}".format(f1),
                  "\tAUC: {:.5f}".format(auc),
                  "\tACC {:.5f}".format(accuracy)
                  )

        precision_all += precision
        recall_all += recall
        f1_all += f1
        auc_all += auc

    print("Test finished",
          "\tTotal Prec: {:.5f}".format(precision_all / test.shape[0]),
          "\tTotal Rec: {:.5f}".format(recall_all / test.shape[0]),
          "\tTotal F1: {:.5f}".format(f1_all / test.shape[0]),
          "\tTotal AUC: {:.5f}".format(auc_all / test.shape[0])
          )
