"Best Practices for Big Data + AI in Healthcare: A Preview" ---- Disease prediction with a Variational Autoencoder (VAE)


Using Variational Autoencoders to predict future diagnosis

VAE for Collaborative Filtering

This work is an adaptation of the work by Dawen Liang et al., who used VAEs for collaborative filtering. Their work exploits the generative nature of VAEs to reconstruct complete user-preference information from an input of partial user-preference information.

Paper by Dawen Liang et al.: https://arxiv.org/abs/1802.05814

Diagnosis Codes

In the healthcare industry, the diagnoses a patient receives are standardized as diagnosis codes: each disease or medical condition is mapped to a diagnosis code.
ICD10 Diagnosis Codes: https://www.icd10data.com/

Data

In the data we are using, we have information on a set of patients and the diagnoses they have received, each mapped to a diagnosis code. The data contains a total of 1567 unique diagnosis codes.
A given patient is therefore represented by a binary vector of dimension 1567, where an element is 1 if the patient has received that particular diagnosis and 0 otherwise.
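As a small illustration (not from the original notebook), this is how such a multi-hot patient vector can be built from a list of ICD-10 codes; the toy code list and the to_multi_hot helper below are hypothetical.

import numpy as np

# Hypothetical vocabulary of diagnosis codes; the real data has 1567 of them
all_codes = ["E11", "I10", "E78", "M19", "J45"]
code_index = {code: i for i, code in enumerate(all_codes)}

def to_multi_hot(patient_codes, code_index):
    """Encode a patient's list of diagnosis codes as a binary (multi-hot) vector."""
    vec = np.zeros(len(code_index), dtype=np.float32)
    for code in patient_codes:
        if code in code_index:
            vec[code_index[code]] = 1.0
    return vec

print(to_multi_hot(["E11", "E78"], code_index))   # [1. 0. 1. 0. 0.]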

VAE for next diagnosis prediction

Given the patient diagnosis information, the VAE encodes it into a latent space. In doing so it learns the distribution of patients and the clusters of diagnoses they tend to receive together.

As a simple example, consider diabetes in older adults: the diagnoses that commonly co-occur in this group are things like diabetes, high cholesterol, high blood pressure, and arthritis.
Now, given a patient whose diagnosis set contains only diabetes and high cholesterol, this patient gets mapped to roughly the same region of the latent space as the older adults with the fuller set of diagnoses mentioned above.

Mapping similar patients to nearby points in the latent space has a very favourable impact on decoding back to the original space: the reconstruction also assigns high probability to diagnoses that are missing from the patient's record but are likely to occur.

This ability to fill in missing diagnoses, a form of collaborative filtering, is why I apply this technique to predict the next diagnosis.
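To make this concrete, here is a minimal sketch (not part of the original notebook) of how the trained two-output model built below could be used to suggest likely missing diagnoses. index_to_code is an assumed mapping from column index back to diagnosis code, and the batch of patient vectors is assumed to match the model's fixed batch size.

import numpy as np

def suggest_next_diagnoses(model, patient_batch, index_to_code, top_k=5):
    """For each patient (multi-hot row), rank the diagnosis codes they do not yet
    have by the reconstructed probability from the model's second output."""
    _, probs = model.predict(patient_batch, batch_size=patient_batch.shape[0])
    suggestions = []
    for vec, p in zip(patient_batch, probs):
        p = p.copy()
        p[vec == 1] = -np.inf                  # mask codes the patient already has
        top = np.argsort(p)[::-1][:top_k]      # highest-probability missing codes first
        suggestions.append([index_to_code[i] for i in top])
    return suggestions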

Applications

This work can be used for many applications, ranging from insurance companies better anticipating a patient's needs to healthcare applications that encourage people to improve their lifestyle choices.

Some background

I came across the application of VAEs to collaborative filtering while working on my earlier project, "Hybrid VAE for Collaborative Filtering", which processes movie plot information from IMDb and uses it as an additional input to improve movie recommendation systems. That work was published in the RecSys 2018 Knowledge Transfer Learning workshop: https://arxiv.org/abs/1808.01006

IMPLEMENTATION

%matplotlib inline
import os
import pickle

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn import preprocessing

import keras
from keras import backend as K
from keras import losses, regularizers
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, Callback
from keras.layers import (Input, Dense, Lambda, Multiply, Dropout, Embedding,
                          Flatten, Activation, Reshape)
from keras.models import Model
from IPython.display import clear_output
Using TensorFlow backend.
os.getcwd()
'C:\\Users\\iz\\Desktop'
df = pd.read_csv('test.csv')
df = df.drop(['KEY'],axis = 1)
df.head()
   T40  A08  I69  Z48  R44  N92  R59  B97  M96  I35  ...  H61  T84  M16  J38  Z90  D68  K83  Z87  Z75  Z43
0    0    0    0    0    0    0    0    0    0    0  ...    0    0    0    0    1    0    0    1    0    0
1    0    0    0    0    0    0    0    0    0    0  ...    0    0    0    0    1    0    0    1    0    0
2    0    1    0    0    0    0    0    0    0    0  ...    0    0    0    0    0    1    0    1    0    0
3    0    0    0    0    0    0    1    0    0    0  ...    0    0    0    0    0    0    0    1    0    0
4    0    0    0    0    0    0    0    0    0    0  ...    0    0    0    0    0    0    0    0    0    0

5 rows × 500 columns

from sklearn.model_selection import train_test_split

# Split patients into train / validation / test sets (90/10, then 90/10 again)
y = np.arange(df.shape[0])
xtrain, xtest, ytrain, ytest = train_test_split(df, y, test_size = 0.1, random_state = 42)
xtrain, xval, ytrain, yval = train_test_split(xtrain, ytrain, test_size = 0.1, random_state = 42)

# Save the splits to disk as numpy arrays
with open('./train.data', 'wb') as f:
    np.save(f, xtrain)
with open('./test.data', 'wb') as f:
    np.save(f, xtest)
with open('./val.data', 'wb') as f:
    np.save(f, xval)
Function to plot Losses
class PlotLosses(Callback):
    def on_train_begin(self, logs={}):
        self.i = 0
        self.x = []
        self.losses = []
        self.val_losses = []        
        self.fig = plt.figure()
        self.logs = []

    def on_epoch_end(self, epoch, logs={}):
        self.logs.append(logs)
        self.x.append(self.i)
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.i += 1
        
        clear_output(wait=True)
        plt.plot(self.x, self.losses, label="loss")
        plt.plot(self.x, self.val_losses, label="val_loss")
        plt.legend()
        plt.show();
        
plot_losses = PlotLosses()
Load Data
with open('train.data', 'rb') as f:
    x_train = np.load(f)
print("number of training users: ", x_train.shape[0])

with open('val.data', 'rb') as f:
    x_val = np.load(f)
print("number of validation users: ", x_val.shape[0])
number of training users:  5234
number of validation users:  582
x_train.shape,x_val.shape
((5234, 500), (582, 500))
x_train[0].shape
(500,)
x_train = x_train[:5000]
x_val = x_val[:500]
Configure Network
# encoder/decoder network size

batch_size=100
original_dim = x_train.shape[1]
intermediate_dim=200
latent_dim=100
nb_epochs=30
epsilon_std=1.0
Using a two-output network

Here, we have two outputs from the network, which differs from the original VAE architecture proposed by Dawen Liang et al.

The first output reconstructs the given input, while the second output produces a probability distribution over the diagnosis codes. Each output has its own loss function that optimizes its particular objective.

Network 1: Same objective functions for the two outputs
#Function to increase the relevance of the KL regularization as the training progresses
class increaseBeta(Callback):
    def __init__(self):
        self.global_beta = 0.0
    def on_train_begin(self, logs={}):
        self.global_beta = 0.0
    def on_epoch_end(self, epoch, logs={}):
        self.global_beta = self.global_beta + 0.01

updateBeta = increaseBeta()

#Function to l2 normalize the inputs
def l2normalize(args):
    _x=args
    return K.l2_normalize(_x, axis = -1)

# Function to sample from the latent space (reparameterization trick)
def sampling(args):
    _mean, _log_var = args
    epsilon = K.random_normal(shape=(K.shape(_mean)[0], latent_dim), mean=0., stddev=epsilon_std)
    return _mean + K.exp(_log_var / 2) * epsilon

# encoder network

x=Input(batch_shape=(batch_size,original_dim))
norm_x = Lambda(l2normalize, output_shape=(original_dim,))(x)
norm_x = Dropout(rate = 0.5)(norm_x)
h=Dense(intermediate_dim, activation='relu')(norm_x)
z_mean=Dense(latent_dim)(h)
z_log_var=Dense(latent_dim)(h)

z= Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# decoder network
h_decoder=Dense(intermediate_dim, activation='relu')
x_bar=Dense(original_dim, activation='sigmoid') 
x_prob=Dense(original_dim, activation='softmax')
h_decoded = h_decoder(z)

# We have two outputs, one which reconstructs the given input, the other which reconstructs the probability 
x_decoded = x_bar(h_decoded)
x_probability = x_prob(h_decoded)

def vae_loss(x, x_bar):
    # Binary cross-entropy reconstruction term plus the closed-form Gaussian KL term,
    # weighted by the beta that the increaseBeta callback anneals upward
    reconst_loss = K.sum(losses.binary_crossentropy(x, x_bar), axis = -1)
    kl_loss = K.sum( 0.5 * (K.exp(z_log_var) - z_log_var + K.square(z_mean) - 1), axis=-1)
    return reconst_loss + (updateBeta.global_beta)*kl_loss

# build and compile model
vae = Model(x, [x_decoded, x_probability])
vae.compile(optimizer='adam', loss=vae_loss, loss_weights=[1., 1.])
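For reference, the kl_loss term above is the standard closed-form KL divergence between the approximate posterior q(z|x) = N(mu, diag(sigma^2)) and the unit-Gaussian prior, where z_mean gives mu and z_log_var gives log sigma^2:

$$
D_{\mathrm{KL}}\big(q(z \mid x)\,\|\,\mathcal{N}(0, I)\big)
= \frac{1}{2}\sum_{j=1}^{d}\Big(\exp(\log\sigma_j^{2}) - \log\sigma_j^{2} + \mu_j^{2} - 1\Big)
$$

The coefficient updateBeta.global_beta is meant to anneal the weight of this term upward during training, a warm-up similar to the KL annealing used by Liang et al.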

weightsPath = "./weights/weights_vae1.hdf5"
x_train,y = [x_train, x_train], batch_size = batch_size, epochs=30,\
#         validation_data=(x_val, [x_val, x_val]), callbacks=[checkpointer, reduce_lr, plot_losses, updateBeta])
vae.fit(x = x_train,y = [x_train, x_train], batch_size = batch_size, epochs=30,\
        validation_data=(x_val, [x_val, x_val]), )# checkpointer = ModelCheckpoint(filepath=weightsPath, verbose=1, save_best_only=True)
# reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)

# vae.fit(x = 
Train on 5000 samples, validate on 500 samples
Epoch 1/30
5000/5000 [==============================] - 4s 758us/step - loss: 78.1787 - dense_23_loss: 34.3244 - dense_24_loss: 43.8543 - val_loss: 64.6742 - val_dense_23_loss: 22.1748 - val_dense_24_loss: 42.4994
Epoch 2/30
5000/5000 [==============================] - 4s 743us/step - loss: 64.4046 - dense_23_loss: 21.9435 - dense_24_loss: 42.4611 - val_loss: 63.6366 - val_dense_23_loss: 21.5553 - val_dense_24_loss: 42.0814
Epoch 3/30
5000/5000 [==============================] - 4s 837us/step - loss: 63.6365 - dense_23_loss: 21.5137 - dense_24_loss: 42.1228 - val_loss: 62.9220 - val_dense_23_loss: 21.1661 - val_dense_24_loss: 41.7559
Epoch 4/30
5000/5000 [==============================] - 4s 727us/step - loss: 62.5700 - dense_23_loss: 20.9471 - dense_24_loss: 41.6229 - val_loss: 61.4582 - val_dense_23_loss: 20.4014 - val_dense_24_loss: 41.0568
Epoch 5/30
5000/5000 [==============================] - 4s 733us/step - loss: 61.5120 - dense_23_loss: 20.3677 - dense_24_loss: 41.1443 - val_loss: 60.5572 - val_dense_23_loss: 19.8999 - val_dense_24_loss: 40.6573
Epoch 6/30
5000/5000 [==============================] - 3s 669us/step - loss: 60.6162 - dense_23_loss: 19.8809 - dense_24_loss: 40.7353 - val_loss: 59.4887 - val_dense_23_loss: 19.3130 - val_dense_24_loss: 40.1757
Epoch 7/30
5000/5000 [==============================] - 4s 879us/step - loss: 59.7845 - dense_23_loss: 19.4161 - dense_24_loss: 40.3683 - val_loss: 58.6560 - val_dense_23_loss: 18.8436 - val_dense_24_loss: 39.8125
Epoch 8/30
5000/5000 [==============================] - 3s 581us/step - loss: 59.0007 - dense_23_loss: 18.9740 - dense_24_loss: 40.0267 - val_loss: 57.7723 - val_dense_23_loss: 18.3500 - val_dense_24_loss: 39.4223
Epoch 9/30
5000/5000 [==============================] - 3s 664us/step - loss: 58.1647 - dense_23_loss: 18.4887 - dense_24_loss: 39.6760 - val_loss: 56.8185 - val_dense_23_loss: 17.8036 - val_dense_24_loss: 39.0149
Epoch 10/30
5000/5000 [==============================] - 4s 705us/step - loss: 57.3740 - dense_23_loss: 18.0296 - dense_24_loss: 39.3444 - val_loss: 56.0715 - val_dense_23_loss: 17.3908 - val_dense_24_loss: 38.6807
Epoch 11/30
5000/5000 [==============================] - 4s 711us/step - loss: 56.7545 - dense_23_loss: 17.6735 - dense_24_loss: 39.0810 - val_loss: 55.1875 - val_dense_23_loss: 16.8682 - val_dense_24_loss: 38.3194
Epoch 12/30
5000/5000 [==============================] - 3s 550us/step - loss: 56.1791 - dense_23_loss: 17.3366 - dense_24_loss: 38.8425 - val_loss: 54.6427 - val_dense_23_loss: 16.5640 - val_dense_24_loss: 38.0787
Epoch 13/30
5000/5000 [==============================] - 3s 568us/step - loss: 55.6814 - dense_23_loss: 17.0418 - dense_24_loss: 38.6397 - val_loss: 54.1010 - val_dense_23_loss: 16.2432 - val_dense_24_loss: 37.8578
Epoch 14/30
C:\ProgramData\Anaconda3\envs\zhongdian\lib\site-packages\keras\callbacks\callbacks.py:95: RuntimeWarning: Method (on_train_batch_end) is slow compared to the batch update (0.111798). Check your callbacks.
  % (hook_name, delta_t_median), RuntimeWarning)

5000/5000 [==============================] - 3s 624us/step - loss: 55.2011 - dense_23_loss: 16.7642 - dense_24_loss: 38.4370 - val_loss: 53.4904 - val_dense_23_loss: 15.8693 - val_dense_24_loss: 37.6210
Epoch 15/30
5000/5000 [==============================] - 3s 699us/step - loss: 54.8549 - dense_23_loss: 16.5582 - dense_24_loss: 38.2967 - val_loss: 53.0989 - val_dense_23_loss: 15.6399 - val_dense_24_loss: 37.4590
Epoch 16/30
5000/5000 [==============================] - 3s 612us/step - loss: 54.5157 - dense_23_loss: 16.3574 - dense_24_loss: 38.1583 - val_loss: 52.5809 - val_dense_23_loss: 15.3208 - val_dense_24_loss: 37.2601
Epoch 17/30
5000/5000 [==============================] - 3s 628us/step - loss: 54.2176 - dense_23_loss: 16.1803 - dense_24_loss: 38.0373 - val_loss: 52.2525 - val_dense_23_loss: 15.1291 - val_dense_24_loss: 37.1233
Epoch 18/30
5000/5000 [==============================] - 4s 744us/step - loss: 53.8684 - dense_23_loss: 15.9648 - dense_24_loss: 37.9036 - val_loss: 51.7984 - val_dense_23_loss: 14.8564 - val_dense_24_loss: 36.9420
Epoch 19/30
5000/5000 [==============================] - 3s 600us/step - loss: 53.5824 - dense_23_loss: 15.7955 - dense_24_loss: 37.7869 - val_loss: 51.4563 - val_dense_23_loss: 14.6374 - val_dense_24_loss: 36.8189
Epoch 20/30
5000/5000 [==============================] - 3s 534us/step - loss: 53.3213 - dense_23_loss: 15.6395 - dense_24_loss: 37.6818 - val_loss: 51.2304 - val_dense_23_loss: 14.5035 - val_dense_24_loss: 36.7269
Epoch 21/30
5000/5000 [==============================] - 4s 701us/step - loss: 53.0680 - dense_23_loss: 15.4855 - dense_24_loss: 37.5825 - val_loss: 50.9055 - val_dense_23_loss: 14.3052 - val_dense_24_loss: 36.6004
Epoch 22/30
5000/5000 [==============================] - 4s 735us/step - loss: 52.8112 - dense_23_loss: 15.3356 - dense_24_loss: 37.4755 - val_loss: 50.6851 - val_dense_23_loss: 14.1706 - val_dense_24_loss: 36.5145
Epoch 23/30
5000/5000 [==============================] - 4s 733us/step - loss: 52.6116 - dense_23_loss: 15.2092 - dense_24_loss: 37.4025 - val_loss: 50.3594 - val_dense_23_loss: 13.9758 - val_dense_24_loss: 36.3836
Epoch 24/30
5000/5000 [==============================] - 4s 849us/step - loss: 52.4208 - dense_23_loss: 15.0949 - dense_24_loss: 37.3259 - val_loss: 50.1555 - val_dense_23_loss: 13.8413 - val_dense_24_loss: 36.3142
Epoch 25/30
5000/5000 [==============================] - 4s 732us/step - loss: 52.2180 - dense_23_loss: 14.9752 - dense_24_loss: 37.2428 - val_loss: 49.9194 - val_dense_23_loss: 13.6918 - val_dense_24_loss: 36.2276
Epoch 26/30
5000/5000 [==============================] - 3s 604us/step - loss: 52.0623 - dense_23_loss: 14.8808 - dense_24_loss: 37.1815 - val_loss: 49.7312 - val_dense_23_loss: 13.5534 - val_dense_24_loss: 36.1778
Epoch 27/30
5000/5000 [==============================] - 3s 507us/step - loss: 51.8958 - dense_23_loss: 14.7736 - dense_24_loss: 37.1222 - val_loss: 49.5618 - val_dense_23_loss: 13.4641 - val_dense_24_loss: 36.0977
Epoch 28/30
5000/5000 [==============================] - 3s 566us/step - loss: 51.7471 - dense_23_loss: 14.6879 - dense_24_loss: 37.0592 - val_loss: 49.3073 - val_dense_23_loss: 13.3000 - val_dense_24_loss: 36.0073
Epoch 29/30
5000/5000 [==============================] - 3s 595us/step - loss: 51.5263 - dense_23_loss: 14.5482 - dense_24_loss: 36.9781 - val_loss: 49.1972 - val_dense_23_loss: 13.2198 - val_dense_24_loss: 35.9775
Epoch 30/30
5000/5000 [==============================] - 3s 551us/step - loss: 51.4380 - dense_23_loss: 14.4985 - dense_24_loss: 36.9395 - val_loss: 48.9715 - val_dense_23_loss: 13.0922 - val_dense_24_loss: 35.8793





<keras.callbacks.callbacks.History at 0x226a912c6d8>
Network 2: Different objective functions for the two outputs
# Function to increase the relevance of the KL regularization as the training progresses

class increaseBeta(Callback):
    def __init__(self):
        self.global_beta = 0.0
    def on_train_begin(self, logs={}):
        self.global_beta = 0.0
    def on_epoch_end(self, epoch, logs={}):
        self.global_beta = self.global_beta + 0.01

updateBeta = increaseBeta()

#Function to l2 normalize the inputs
def l2normalize(args):
    _x=args
    return K.l2_normalize(_x, axis = -1)

# Function to sample from the latent space (reparameterization trick)
def sampling(args):
    _mean, _log_var = args
    epsilon = K.random_normal(shape=(K.shape(_mean)[0], latent_dim), mean=0., stddev=epsilon_std)
    return _mean + K.exp(_log_var / 2) * epsilon


# encoder network
x=Input(batch_shape=(batch_size,original_dim))
norm_x = Lambda(l2normalize, output_shape=(original_dim,))(x)
norm_x = Dropout(rate = 0.5)(norm_x)
h=Dense(intermediate_dim, activation='relu')(norm_x)
z_mean=Dense(latent_dim)(h)
z_log_var=Dense(latent_dim)(h)

z= Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# decoder network
h_decoder=Dense(intermediate_dim, activation='relu')
x_bar=Dense(original_dim, activation='sigmoid') 
x_prob=Dense(original_dim, activation='softmax')
h_decoded = h_decoder(z)
#We have two outputs, one which reconstructs the given input, the other which reconstructs the probability 
x_decoded = x_bar(h_decoded)
x_probability = x_prob(h_decoded)

def vae_loss1(x, x_bar):
    # Binary cross-entropy reconstruction term plus the annealed KL term (for the sigmoid output)
    reconst_loss = K.sum(losses.binary_crossentropy(x, x_bar), axis = -1)
    kl_loss = K.sum( 0.5 * (K.exp(z_log_var) - z_log_var + K.square(z_mean) - 1), axis=-1)
    return reconst_loss + (updateBeta.global_beta)*kl_loss

def vae_loss2(x, x_bar):
    # Reconstruction term for the softmax output: the negative dot product between the
    # multi-hot target x and the predicted probabilities x_bar, which rewards placing
    # probability mass on the patient's observed diagnosis codes
    neg_ll = -K.sum(x_bar*x, axis = -1)
    kl_loss = K.sum( 0.5 * (K.exp(z_log_var) - z_log_var + K.square(z_mean) - 1), axis=-1)
    return neg_ll + (updateBeta.global_beta)*kl_loss

# build and compile model
vae2 = Model(x, [x_decoded, x_probability])
vae2.compile(optimizer='adam', loss=[vae_loss1, vae_loss2], loss_weights=[0.5, 0.5])

# weightsPath = "./weights/weights_vae2.hdf5"
# checkpointer = ModelCheckpoint(filepath=weightsPath, verbose=1, save_best_only=True)
# reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)

# vae2.fit(x = x_train,y = [x_train, x_train], batch_size = batch_size, epochs=30,\
#         validation_data=(x_val, [x_val, x_val]), callbacks=[checkpointer, reduce_lr, plot_losses, updateBeta])
vae2.fit(x = x_train,y = [x_train, x_train], batch_size = batch_size, epochs=30,\
        validation_data=(x_val, [x_val, x_val]), )
Train on 5000 samples, validate on 500 samples
Epoch 1/30
5000/5000 [==============================] - 4s 729us/step - loss: 16.9416 - dense_29_loss: 34.2665 - dense_30_loss: -0.3832 - val_loss: 10.8511 - val_dense_29_loss: 22.3960 - val_dense_30_loss: -0.6938
Epoch 2/30
5000/5000 [==============================] - 3s 596us/step - loss: 10.6567 - dense_29_loss: 22.0604 - dense_30_loss: -0.7470 - val_loss: 10.5063 - val_dense_29_loss: 21.7948 - val_dense_30_loss: -0.7821
Epoch 3/30
5000/5000 [==============================] - 3s 587us/step - loss: 10.4868 - dense_29_loss: 21.7343 - dense_30_loss: -0.7607 - val_loss: 10.3348 - val_dense_29_loss: 21.4568 - val_dense_30_loss: -0.7872
Epoch 4/30
5000/5000 [==============================] - 3s 588us/step - loss: 10.3471 - dense_29_loss: 21.4561 - dense_30_loss: -0.7619 - val_loss: 10.2142 - val_dense_29_loss: 21.2168 - val_dense_30_loss: -0.7883
Epoch 5/30
5000/5000 [==============================] - 3s 583us/step - loss: 10.1537 - dense_29_loss: 21.0695 - dense_30_loss: -0.7622 - val_loss: 9.8815 - val_dense_29_loss: 20.5522 - val_dense_30_loss: -0.7891
Epoch 6/30
5000/5000 [==============================] - 4s 700us/step - loss: 9.8399 - dense_29_loss: 20.4424 - dense_30_loss: -0.7627 - val_loss: 9.6098 - val_dense_29_loss: 20.0090 - val_dense_30_loss: -0.7894
Epoch 7/30
5000/5000 [==============================] - 3s 661us/step - loss: 9.6009 - dense_29_loss: 19.9645 - dense_30_loss: -0.7627 - val_loss: 9.3354 - val_dense_29_loss: 19.4604 - val_dense_30_loss: -0.7896
Epoch 8/30
5000/5000 [==============================] - 4s 732us/step - loss: 9.3988 - dense_29_loss: 19.5603 - dense_30_loss: -0.7628 - val_loss: 9.1374 - val_dense_29_loss: 19.0644 - val_dense_30_loss: -0.7897
Epoch 9/30
5000/5000 [==============================] - 3s 698us/step - loss: 9.1813 - dense_29_loss: 19.1254 - dense_30_loss: -0.7629 - val_loss: 8.9246 - val_dense_29_loss: 18.6389 - val_dense_30_loss: -0.7896
Epoch 10/30
5000/5000 [==============================] - 4s 761us/step - loss: 8.9580 - dense_29_loss: 18.6789 - dense_30_loss: -0.7629 - val_loss: 8.6732 - val_dense_29_loss: 18.1362 - val_dense_30_loss: -0.7898
Epoch 11/30
5000/5000 [==============================] - 3s 695us/step - loss: 8.7604 - dense_29_loss: 18.2836 - dense_30_loss: -0.7629 - val_loss: 8.3996 - val_dense_29_loss: 17.5891 - val_dense_30_loss: -0.7899
Epoch 12/30
5000/5000 [==============================] - 3s 628us/step - loss: 8.5640 - dense_29_loss: 17.8911 - dense_30_loss: -0.7630 - val_loss: 8.1917 - val_dense_29_loss: 17.1734 - val_dense_30_loss: -0.7899
Epoch 13/30
5000/5000 [==============================] - 3s 616us/step - loss: 8.3888 - dense_29_loss: 17.5405 - dense_30_loss: -0.7630 - val_loss: 7.9692 - val_dense_29_loss: 16.7284 - val_dense_30_loss: -0.7900
Epoch 14/30
5000/5000 [==============================] - 3s 638us/step - loss: 8.2512 - dense_29_loss: 17.2653 - dense_30_loss: -0.7630 - val_loss: 7.8081 - val_dense_29_loss: 16.4062 - val_dense_30_loss: -0.7900
Epoch 15/30
5000/5000 [==============================] - 4s 777us/step - loss: 8.1205 - dense_29_loss: 17.0040 - dense_30_loss: -0.7630 - val_loss: 7.6419 - val_dense_29_loss: 16.0737 - val_dense_30_loss: -0.7900
Epoch 16/30
5000/5000 [==============================] - 5s 972us/step - loss: 7.9886 - dense_29_loss: 16.7401 - dense_30_loss: -0.7630 - val_loss: 7.5225 - val_dense_29_loss: 15.8349 - val_dense_30_loss: -0.7900
Epoch 17/30
5000/5000 [==============================] - 4s 724us/step - loss: 7.8793 - dense_29_loss: 16.5216 - dense_30_loss: -0.7630 - val_loss: 7.4022 - val_dense_29_loss: 15.5944 - val_dense_30_loss: -0.7900
Epoch 18/30
5000/5000 [==============================] - 4s 753us/step - loss: 7.7752 - dense_29_loss: 16.3134 - dense_30_loss: -0.7630 - val_loss: 7.2394 - val_dense_29_loss: 15.2688 - val_dense_30_loss: -0.7900
Epoch 19/30
5000/5000 [==============================] - 4s 841us/step - loss: 7.6858 - dense_29_loss: 16.1345 - dense_30_loss: -0.7630 - val_loss: 7.1225 - val_dense_29_loss: 15.0350 - val_dense_30_loss: -0.7900
Epoch 20/30
5000/5000 [==============================] - 4s 801us/step - loss: 7.5966 - dense_29_loss: 15.9562 - dense_30_loss: -0.7630 - val_loss: 6.9988 - val_dense_29_loss: 14.7876 - val_dense_30_loss: -0.7900
Epoch 21/30
5000/5000 [==============================] - 4s 744us/step - loss: 7.5157 - dense_29_loss: 15.7943 - dense_30_loss: -0.7630 - val_loss: 6.9055 - val_dense_29_loss: 14.6010 - val_dense_30_loss: -0.7900
Epoch 22/30
5000/5000 [==============================] - 4s 847us/step - loss: 7.4233 - dense_29_loss: 15.6096 - dense_30_loss: -0.7630 - val_loss: 6.8266 - val_dense_29_loss: 14.4433 - val_dense_30_loss: -0.7900
Epoch 23/30
5000/5000 [==============================] - 3s 668us/step - loss: 7.3547 - dense_29_loss: 15.4724 - dense_30_loss: -0.7630 - val_loss: 6.7169 - val_dense_29_loss: 14.2239 - val_dense_30_loss: -0.7900
Epoch 24/30
5000/5000 [==============================] - 3s 653us/step - loss: 7.2897 - dense_29_loss: 15.3423 - dense_30_loss: -0.7630 - val_loss: 6.6318 - val_dense_29_loss: 14.0535 - val_dense_30_loss: -0.7900
Epoch 25/30
5000/5000 [==============================] - 3s 693us/step - loss: 7.2195 - dense_29_loss: 15.2020 - dense_30_loss: -0.7630 - val_loss: 6.5772 - val_dense_29_loss: 13.9444 - val_dense_30_loss: -0.7900
Epoch 26/30
5000/5000 [==============================] - 4s 846us/step - loss: 7.1611 - dense_29_loss: 15.0852 - dense_30_loss: -0.7630 - val_loss: 6.5330 - val_dense_29_loss: 13.8560 - val_dense_30_loss: -0.7900
Epoch 27/30
5000/5000 [==============================] - 4s 817us/step - loss: 7.1067 - dense_29_loss: 14.9764 - dense_30_loss: -0.7630 - val_loss: 6.4391 - val_dense_29_loss: 13.6681 - val_dense_30_loss: -0.7900
Epoch 28/30
5000/5000 [==============================] - 4s 779us/step - loss: 7.0528 - dense_29_loss: 14.8687 - dense_30_loss: -0.7630 - val_loss: 6.3591 - val_dense_29_loss: 13.5081 - val_dense_30_loss: -0.7900
Epoch 29/30
5000/5000 [==============================] - 4s 728us/step - loss: 6.9904 - dense_29_loss: 14.7439 - dense_30_loss: -0.7630 - val_loss: 6.2726 - val_dense_29_loss: 13.3353 - val_dense_30_loss: -0.7900
Epoch 30/30
5000/5000 [==============================] - 5s 901us/step - loss: 6.9431 - dense_29_loss: 14.6491 - dense_30_loss: -0.7630 - val_loss: 6.2600 - val_dense_29_loss: 13.3100 - val_dense_30_loss: -0.7900





<keras.callbacks.callbacks.History at 0x226ad466748>
with open('test.data', 'rb') as f:
    x_test = np.load(f)
print("number of testing users: ", x_test.shape[0])
number of testing users:  647
x_test = x_test[:600]
x_test.shape
(600, 500)
# x_test[0]
Calculating Recall

We test the trained system as follows. For each patient:

  1. We choose one diagnosis at random from the M diagnoses for which the patient has the value 1 (i.e. diagnoses the patient has received).
  2. We set that diagnosis to 0.
  3. We pass the resulting vector through the network to obtain a probability distribution over the diagnosis codes.
  4. We sort the diagnosis codes by their predicted probabilities.
  5. The network was given an input with M-1 diagnoses. We now calculate recall@k as the percentage of patients for whom the held-out diagnosis appears among the top (M-1)+k codes by probability, as illustrated by the small example below.
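To make step 5 concrete, here is a tiny worked example with made-up numbers (not part of the original notebook) of the ranking check that calc_heldout_recall performs below.

import numpy as np

# A patient had M = 4 diagnoses, one of which (index 2) was held out, so the
# network saw M - 1 = 3 codes and must rank the held-out code highly.
probs = np.array([0.05, 0.30, 0.20, 0.10, 0.25, 0.02, 0.08])  # assumed model output over 7 codes
held_out = 2
m_minus_1 = 3
k = 1

ranked = np.argsort(probs)                 # ascending order, as in calc_heldout_recall
top_slots = ranked[-(k + m_minus_1):]      # the top (M-1)+k positions
print(held_out in top_slots)               # True -> counts as a hit for recall@1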



# Hold out one random diagnosis per patient (sampled from the first 473 codes)
x_test_hold_new = np.copy(x_test)
hold_out_ind_new = [np.random.choice(np.nonzero(i)[0]) for i in x_test[:,:473]]
for i in range(x_test.shape[0]) :
    x_test_hold_new[i][hold_out_ind_new[i]] = 0

def calc_heldout_recall_new(x_test, x_rec, k):
    # recall@k over the held-out diagnoses; patients with fewer than 5 remaining
    # diagnoses are skipped, and count/tot start at 1.0 to guard against division by zero
    count = 1.0
    tot = 1.0
    x_rank = np.argsort(x_rec)
    for i in range(x_rank.shape[0]):
        sm = np.sum(x_test[i])-1          # M - 1 diagnoses were shown to the network
        if sm < 5:
            continue
        else:
            tot +=1
            if hold_out_ind_new[i] in x_rank[i][-(k+sm):]:
                count+=1.0
    return count/tot
x_rec, x_prob = vae.predict(x_test_hold_new, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall_new(x_test, x_prob[:,:473], k))

x_rec, x_prob = vae2.predict(x_test_hold_new, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall_new(x_test, x_prob[:,:473], k))


# Same hold-out procedure, but this time sampling the held-out code from all diagnosis columns
x_test_hold = np.copy(x_test)
hold_out_ind = [np.random.choice(np.nonzero(i)[0]) for i in x_test]
for i in range(x_test.shape[0]) :
    x_test_hold[i][hold_out_ind[i]] = 0
def calc_heldout_recall(x_test, x_rec, k):
    count = 1.0
    tot = 1.0
    x_rank = np.argsort(x_rec)
    for i in range(x_rank.shape[0]):
        sm = np.sum(x_test[i])-1
        if sm < 5:
            continue
        else:
            tot +=1
            if hold_out_ind[i] in x_rank[i][-(k+sm):]:
                count+=1.0
    return count/tot
x_rec, x_prob = vae.predict(x_test_hold, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall(x_test, x_prob, k))
x_rec, x_prob = vae2.predict(x_test_hold, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall(x_test, x_prob, k))
Impact of different ways of calculating the objective function

We can see that recall@k for k = 1, 2, 3, 4, 5, 10, 15 is quite strong, considering that the network had to choose among 1500 other diagnoses.

An interesting observation is that the second way of formulating the objective captures recall at smaller values of k better than the first approach, while the opposite holds for larger values of k.
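For reference (my summary, not part of the original notebook), the two setups compared above differ only in the loss attached to the softmax output. Writing x̂ for the sigmoid output, p for the softmax output, and D_KL for the KL term defined earlier:

$$
\text{Network 1:}\quad \mathcal{L} = \sum_i \mathrm{BCE}(x_i, \hat{x}_i) + \beta D_{\mathrm{KL}} \quad \text{(applied to both outputs)}
$$

$$
\text{Network 2:}\quad \mathcal{L}_1 = \sum_i \mathrm{BCE}(x_i, \hat{x}_i) + \beta D_{\mathrm{KL}}, \qquad
\mathcal{L}_2 = -\sum_i x_i\, p_i + \beta D_{\mathrm{KL}}
$$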

