Using Variational Autoencoders to predict future diagnoses
VAE for Collaborative Filtering
This work is an adaptation of the work by Dawen Liang et al., who used VAEs for collaborative filtering. Their approach exploits the generative nature of VAEs to reconstruct a complete user-preference vector from a partially observed one.
Work by Dawen Liang et al.: https://arxiv.org/abs/1802.05814
Diagnosis Codes
In the healthcare industry, diagnoses have been standardized as diagnosis codes: each disease or medical condition is mapped to a diagnosis code.
ICD10 Diagnosis Codes: https://www.icd10data.com/
Data
The data describes a set of patients and the diagnoses each of them has received, with every diagnosis mapped to a diagnosis code. The data contains 1567 unique diagnosis codes in total. The file used in the implementation below (test.csv) has 500 code columns.
Each patient is therefore represented as a binary vector with one element per diagnosis code. An element is 1 if the patient has received that diagnosis and 0 otherwise.
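Here is a minimal sketch of this multi-hot encoding. It assumes each patient's record is available as a list of code strings. The helper name `encode_patient` and the toy vocabulary are hypothetical and do not come from the dataset.

```python
def encode_patient(patient_codes, vocabulary):
    """Return a binary (multi-hot) vector with one slot per code in `vocabulary`.

    A slot is 1 if the patient has that diagnosis code and 0 otherwise.
    Codes that are not in the vocabulary are ignored.
    """
    index = {code: i for i, code in enumerate(vocabulary)}
    vector = [0] * len(vocabulary)
    for code in patient_codes:
        if code in index:
            vector[index[code]] = 1
    return vector


# Hypothetical vocabulary of ICD-10 category codes (the real data has one column per code)
vocabulary = ["E11", "E78", "I10", "M19", "Z87"]
print(encode_patient(["E11", "I10"], vocabulary))  # [1, 0, 1, 0, 0]
```

The column order of the vocabulary fixes the meaning of each position. The same vocabulary must therefore be used for training and prediction.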
VAE for next diagnosis prediction
Given a patient's diagnosis vector, the VAE encodes it into a latent space. In doing so, it learns how patients are distributed and which clusters of diagnoses tend to occur together.
As a simple example, consider older adults with diabetes. The diagnoses that commonly co-occur in this group might be diabetes, high cholesterol, high blood pressure, arthritis, and so on.
A patient whose record lists only diabetes and high cholesterol would then be mapped to the same region of the latent space as the older adults described above.
Mapping similar patients to nearby points in the latent space pays off when decoding back to the original space. Diagnoses that the patient does not yet have, but that occur with high probability among similar patients, are also reconstructed.
This ability to fill in missing diagnoses, a form of collaborative filtering, is why I apply this technique to predict the next diagnosis.
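The fill-in step can be sketched as follows. The sketch assumes the decoder has already produced one score per code for a patient. `suggest_missing` is a hypothetical helper, and the scores in the example are made up for illustration. Codes the patient already has are excluded, and the remaining codes are ranked by their scores.

```python
def suggest_missing(scores, observed, k):
    """Return indices of the k highest-scoring codes the patient does not already have.

    scores   -- one decoder score (e.g. a probability) per diagnosis code
    observed -- binary vector, 1 where the patient already has the code
    k        -- number of suggestions to return
    """
    candidates = [i for i, seen in enumerate(observed) if not seen]
    # Sort the unobserved codes by decreasing score (the sort is stable, so ties keep index order)
    candidates.sort(key=lambda i: scores[i], reverse=True)
    return candidates[:k]


observed = [1, 1, 0, 0, 0]               # e.g. diabetes and high cholesterol present
scores = [0.90, 0.80, 0.60, 0.70, 0.05]  # illustrative decoder output
print(suggest_missing(scores, observed, 2))  # [3, 2]
```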
Applications
This work can be used in many applications. These range from insurance companies better predicting a patient’s needs to healthcare applications that encourage people to improve their lifestyle choices.
Some background
I came across the use of VAEs for collaborative filtering while working on my previous project, “Hybrid VAE for Collaborative Filtering”. That work processes movie plot information from IMDb and uses it as an additional input to improve movie recommendation systems. It was published at the RecSys 2018 Knowledge Transfer Learning workshop: https://arxiv.org/abs/1808.01006
IMPLEMENTATION
%matplotlib inline
import numpy as np
import pickle
import os
from matplotlib import pyplot as plt
from keras.layers import Input, Dense, Lambda, Multiply, Dropout,Embedding, Flatten, Activation, Reshape
from keras.models import Model
from keras import losses
from keras import backend as K
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, Callback
from IPython.display import clear_output
from sklearn import preprocessing
from keras import regularizers
import keras
import pandas as pd
Using TensorFlow backend.
import os
os.getcwd()
'C:\\Users\\iz\\Desktop'
df = pd.read_csv('test.csv')
df = df.drop(['KEY'],axis = 1)
df.head()
| | T40 | A08 | I69 | Z48 | R44 | N92 | R59 | B97 | M96 | I35 | ... | H61 | T84 | M16 | J38 | Z90 | D68 | K83 | Z87 | Z75 | Z43 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 |
| 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | 0 |
| 2 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 1 | 0 | 0 |
| 3 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 |
| 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
5 rows × 500 columns
from sklearn.model_selection import train_test_split
y = range(df.shape[0])
xtrain, xtest, ytrain, ytest = train_test_split(df, y, test_size = 0.1, random_state = 42)
xtrain, xval, ytrain, yval = train_test_split(xtrain, ytrain, test_size = 0.1, random_state = 42)
with open('./train.data', 'wb') as f:
    np.save(f, xtrain)
with open('./test.data', 'wb') as f:
    np.save(f, xtest)
with open('./val.data', 'wb') as f:
    np.save(f, xval)
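The two successive 90/10 splits give roughly an 81/9/10 partition into train, validation and test sets. The sketch below reproduces the arithmetic. It assumes scikit-learn's rule for a fractional `test_size`: the held-out part gets `ceil(test_size * n)` rows and the rest stays in the training part. The total of 6463 patients is an assumption inferred from the counts printed in this notebook (5234 + 582 + 647). It is not stated in the data description.

```python
import math


def split_sizes(n_patients, test_size=0.1, val_size=0.1):
    """Return (train, val, test) row counts for two successive fractional splits.

    Follows the rule train_test_split uses for a float test_size:
    the held-out part gets ceil(test_size * n) rows, the rest stays in train.
    """
    n_test = math.ceil(test_size * n_patients)
    n_rest = n_patients - n_test
    n_val = math.ceil(val_size * n_rest)
    n_train = n_rest - n_val
    return n_train, n_val, n_test


print(split_sizes(6463))  # (5234, 582, 647)
```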
Callback to plot the training and validation losses
class PlotLosses(Callback):
    def on_train_begin(self, logs={}):
        self.i = 0
        self.x = []
        self.losses = []
        self.val_losses = []
        self.fig = plt.figure()
        self.logs = []

    def on_epoch_end(self, epoch, logs={}):
        self.logs.append(logs)
        self.x.append(self.i)
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.i += 1
        # Redraw both loss curves after every epoch
        clear_output(wait=True)
        plt.plot(self.x, self.losses, label="loss")
        plt.plot(self.x, self.val_losses, label="val_loss")
        plt.legend()
        plt.show()

plot_losses = PlotLosses()
Load Data
with open('train.data', 'rb') as f:
    x_train = np.load(f)
print("number of training users: ", x_train.shape[0])
with open('val.data', 'rb') as f:
    x_val = np.load(f)
print("number of validation users: ", x_val.shape[0])
number of training users: 5234
number of validation users: 582
x_train.shape,x_val.shape
((5234, 500), (582, 500))
x_train[0].shape
(500,)
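# The Input layer below uses a fixed batch_shape, so the number of rows must be a multiple of batch_size (100)
x_train = x_train[:5000]
x_val = x_val[:500]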
Configure Network
# encoder/decoder network size
batch_size=100
original_dim = x_train.shape[1]
intermediate_dim=200
latent_dim=100
nb_epochs=30
epsilon_std=1.0
Using a two-output network
Here the network has two outputs. This differs from the single-output VAE proposed by Dawen Liang et al.
The first output (sigmoid) reconstructs the given input, and the second output (softmax) gives a probability distribution over the diagnosis codes. Each output has its own loss term, and the two terms are combined through loss_weights when the model is compiled.
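The sketch below makes the two kinds of objective concrete for a single patient, in plain Python. It computes the binary cross-entropy summed over codes, which suits a sigmoid output. For comparison, it also computes the multinomial negative log-likelihood that Liang et al. use for a softmax output. The helper names and numbers are illustrative assumptions. The Keras code below defines its own loss functions, which are not identical to these.

```python
import math


def binary_cross_entropy(x, p, eps=1e-7):
    """Sum over codes of -[x log p + (1 - x) log(1 - p)] for one patient."""
    total = 0.0
    for xi, pi in zip(x, p):
        pi = min(max(pi, eps), 1.0 - eps)  # clip to avoid log(0)
        total -= xi * math.log(pi) + (1 - xi) * math.log(1.0 - pi)
    return total


def multinomial_nll(x, probs, eps=1e-7):
    """-sum_i x_i log(probs_i), where probs is a softmax distribution over codes."""
    return -sum(xi * math.log(max(pi, eps)) for xi, pi in zip(x, probs))


x = [1, 0, 1, 0]                    # observed codes for one patient
sigmoid_out = [0.9, 0.2, 0.7, 0.1]  # independent per-code probabilities
softmax_out = [0.4, 0.1, 0.4, 0.1]  # one distribution over all codes (sums to 1)
print(round(binary_cross_entropy(x, sigmoid_out), 4))
print(round(multinomial_nll(x, softmax_out), 4))
```

The binary cross-entropy scores each code independently. The multinomial term only rewards placing probability mass on the observed codes. Because that mass is shared across all codes, the multinomial term behaves more like a ranking objective.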
Network 1: Same objective functions for the two outputs
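# Callback to increase the weight (beta) of the KL regularization as training progresses.
# beta is held in a backend variable so that the per-epoch updates reach the compiled loss;
# a plain Python float is read only once, when the loss graph is built.
class increaseBeta(Callback):
    def __init__(self):
        super(increaseBeta, self).__init__()
        self.global_beta = K.variable(0.0)

    def on_train_begin(self, logs={}):
        K.set_value(self.global_beta, 0.0)

    def on_epoch_end(self, epoch, logs={}):
        K.set_value(self.global_beta, K.get_value(self.global_beta) + 0.01)

updateBeta = increaseBeta()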
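The increaseBeta callback is meant to implement linear KL annealing. β is 0 during the first epoch and grows by 0.01 after every completed epoch. The following stdlib sketch computes the β used in each epoch. The function name `beta_for_epoch` and the optional `cap` argument are hypothetical and not part of the original callback. Capping β is a common variant.

```python
def beta_for_epoch(epoch, step=0.01, cap=None):
    """β used during a given 0-indexed epoch under linear KL annealing.

    β is 0 during the first epoch and increases by `step` after every
    completed epoch. `cap` (optional, hypothetical) stops the increase
    at a maximum value.
    """
    beta = step * epoch
    if cap is not None:
        beta = min(beta, cap)
    return beta


schedule = [round(beta_for_epoch(e), 2) for e in range(30)]
print(schedule[0], schedule[-1])  # 0.0 0.29
```

With 30 epochs, the KL term therefore never gets a weight above 0.29, and only if the callback is actually passed to fit.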
# Function to L2-normalize the inputs
def l2normalize(args):
    _x = args
    return K.l2_normalize(_x, axis=-1)
# Function to sample from the latent space (reparameterization trick)
def sampling(args):
    _mean, _log_var = args
    epsilon = K.random_normal(shape=(K.shape(_mean)[0], latent_dim), mean=0., stddev=epsilon_std)
    return _mean + K.exp(_log_var / 2) * epsilon
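sampling implements the reparameterization trick. Instead of drawing z directly from N(μ, σ²), it draws ε ~ N(0, epsilon_std²) and returns z = μ + exp(log σ² / 2)·ε. This keeps the path from z_mean and z_log_var to z differentiable. The stdlib sketch below handles a single latent dimension and checks empirically that the samples have the intended mean and standard deviation. The function name `reparameterize` is hypothetical.

```python
import math
import random


def reparameterize(mean, log_var, rng, epsilon_std=1.0):
    """Draw z = mean + exp(log_var / 2) * eps with eps ~ N(0, epsilon_std**2)."""
    eps = rng.gauss(0.0, epsilon_std)
    return mean + math.exp(log_var / 2.0) * eps


rng = random.Random(0)
mean, log_var = 2.0, math.log(0.25)  # target standard deviation: 0.5
samples = [reparameterize(mean, log_var, rng) for _ in range(20000)]
sample_mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((s - sample_mean) ** 2 for s in samples) / len(samples))
print(round(sample_mean, 2), round(sample_std, 2))
```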
# encoder network
x=Input(batch_shape=(batch_size,original_dim))
norm_x = Lambda(l2normalize, output_shape=(original_dim,))(x)
norm_x = Dropout(rate = 0.5)(norm_x)
h=Dense(intermediate_dim, activation='relu')(norm_x)
z_mean=Dense(latent_dim)(h)
z_log_var=Dense(latent_dim)(h)
z= Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
# decoder network
h_decoder=Dense(intermediate_dim, activation='relu')
x_bar=Dense(original_dim, activation='sigmoid')
x_prob=Dense(original_dim, activation='softmax')
h_decoded = h_decoder(z)
# We have two outputs: one reconstructs the given input, the other gives a probability distribution over the diagnosis codes
x_decoded = x_bar(h_decoded)
x_probability = x_prob(h_decoded)
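def vae_loss(x, x_bar):
    # losses.binary_crossentropy already averages over the codes and returns one value
    # per patient, so this K.sum adds those values up over the batch
    reconst_loss = K.sum(losses.binary_crossentropy(x, x_bar), axis=-1)
    # Closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I), per patient
    kl_loss = K.sum(0.5 * (K.exp(z_log_var) - z_log_var + K.square(z_mean) - 1), axis=-1)
    return reconst_loss + updateBeta.global_beta * kl_loss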
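Written out, vae_loss is a reconstruction term plus a β-weighted KL divergence between the encoder's Gaussian and a standard normal prior. The KL expression in the code is the closed form for a diagonal Gaussian, with $\mu$ = z_mean, $\log\sigma^2$ = z_log_var and $d$ = latent_dim:

$$
\mathcal{L}(x) = \mathcal{L}_{\mathrm{rec}}(x, \bar{x}) + \beta\, D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big),
\qquad
D_{\mathrm{KL}} = \frac{1}{2}\sum_{j=1}^{d}\left(\sigma_j^2 - \log\sigma_j^2 + \mu_j^2 - 1\right)
$$

Each term of the sum is non-negative. It is zero exactly when $\mu_j = 0$ and $\sigma_j^2 = 1$, so β controls how strongly the latent codes are pulled toward the prior.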
# build and compile model
vae = Model(x, [x_decoded, x_probability])
vae.compile(optimizer='adam', loss=vae_loss, loss_weights=[1., 1.])
weightsPath = "./weights/weights_vae1.hdf5"
# checkpointer = ModelCheckpoint(filepath=weightsPath, verbose=1, save_best_only=True)
# reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)
# vae.fit(x = x_train,y = [x_train, x_train], batch_size = batch_size, epochs=30,\
#         validation_data=(x_val, [x_val, x_val]), callbacks=[checkpointer, reduce_lr, plot_losses, updateBeta])
# updateBeta is not in the callbacks list of this run, so beta stays at 0
vae.fit(x=x_train, y=[x_train, x_train], batch_size=batch_size, epochs=30,
        validation_data=(x_val, [x_val, x_val]))
Train on 5000 samples, validate on 500 samples
Epoch 1/30
5000/5000 [==============================] - 4s 758us/step - loss: 78.1787 - dense_23_loss: 34.3244 - dense_24_loss: 43.8543 - val_loss: 64.6742 - val_dense_23_loss: 22.1748 - val_dense_24_loss: 42.4994
Epoch 2/30
5000/5000 [==============================] - 4s 743us/step - loss: 64.4046 - dense_23_loss: 21.9435 - dense_24_loss: 42.4611 - val_loss: 63.6366 - val_dense_23_loss: 21.5553 - val_dense_24_loss: 42.0814
Epoch 3/30
5000/5000 [==============================] - 4s 837us/step - loss: 63.6365 - dense_23_loss: 21.5137 - dense_24_loss: 42.1228 - val_loss: 62.9220 - val_dense_23_loss: 21.1661 - val_dense_24_loss: 41.7559
Epoch 4/30
5000/5000 [==============================] - 4s 727us/step - loss: 62.5700 - dense_23_loss: 20.9471 - dense_24_loss: 41.6229 - val_loss: 61.4582 - val_dense_23_loss: 20.4014 - val_dense_24_loss: 41.0568
Epoch 5/30
5000/5000 [==============================] - 4s 733us/step - loss: 61.5120 - dense_23_loss: 20.3677 - dense_24_loss: 41.1443 - val_loss: 60.5572 - val_dense_23_loss: 19.8999 - val_dense_24_loss: 40.6573
Epoch 6/30
5000/5000 [==============================] - 3s 669us/step - loss: 60.6162 - dense_23_loss: 19.8809 - dense_24_loss: 40.7353 - val_loss: 59.4887 - val_dense_23_loss: 19.3130 - val_dense_24_loss: 40.1757
Epoch 7/30
5000/5000 [==============================] - 4s 879us/step - loss: 59.7845 - dense_23_loss: 19.4161 - dense_24_loss: 40.3683 - val_loss: 58.6560 - val_dense_23_loss: 18.8436 - val_dense_24_loss: 39.8125
Epoch 8/30
5000/5000 [==============================] - 3s 581us/step - loss: 59.0007 - dense_23_loss: 18.9740 - dense_24_loss: 40.0267 - val_loss: 57.7723 - val_dense_23_loss: 18.3500 - val_dense_24_loss: 39.4223
Epoch 9/30
5000/5000 [==============================] - 3s 664us/step - loss: 58.1647 - dense_23_loss: 18.4887 - dense_24_loss: 39.6760 - val_loss: 56.8185 - val_dense_23_loss: 17.8036 - val_dense_24_loss: 39.0149
Epoch 10/30
5000/5000 [==============================] - 4s 705us/step - loss: 57.3740 - dense_23_loss: 18.0296 - dense_24_loss: 39.3444 - val_loss: 56.0715 - val_dense_23_loss: 17.3908 - val_dense_24_loss: 38.6807
Epoch 11/30
5000/5000 [==============================] - 4s 711us/step - loss: 56.7545 - dense_23_loss: 17.6735 - dense_24_loss: 39.0810 - val_loss: 55.1875 - val_dense_23_loss: 16.8682 - val_dense_24_loss: 38.3194
Epoch 12/30
5000/5000 [==============================] - 3s 550us/step - loss: 56.1791 - dense_23_loss: 17.3366 - dense_24_loss: 38.8425 - val_loss: 54.6427 - val_dense_23_loss: 16.5640 - val_dense_24_loss: 38.0787
Epoch 13/30
5000/5000 [==============================] - 3s 568us/step - loss: 55.6814 - dense_23_loss: 17.0418 - dense_24_loss: 38.6397 - val_loss: 54.1010 - val_dense_23_loss: 16.2432 - val_dense_24_loss: 37.8578
Epoch 14/30
C:\ProgramData\Anaconda3\envs\zhongdian\lib\site-packages\keras\callbacks\callbacks.py:95: RuntimeWarning: Method (on_train_batch_end) is slow compared to the batch update (0.111798). Check your callbacks.
% (hook_name, delta_t_median), RuntimeWarning)
5000/5000 [==============================] - 3s 624us/step - loss: 55.2011 - dense_23_loss: 16.7642 - dense_24_loss: 38.4370 - val_loss: 53.4904 - val_dense_23_loss: 15.8693 - val_dense_24_loss: 37.6210
Epoch 15/30
5000/5000 [==============================] - 3s 699us/step - loss: 54.8549 - dense_23_loss: 16.5582 - dense_24_loss: 38.2967 - val_loss: 53.0989 - val_dense_23_loss: 15.6399 - val_dense_24_loss: 37.4590
Epoch 16/30
5000/5000 [==============================] - 3s 612us/step - loss: 54.5157 - dense_23_loss: 16.3574 - dense_24_loss: 38.1583 - val_loss: 52.5809 - val_dense_23_loss: 15.3208 - val_dense_24_loss: 37.2601
Epoch 17/30
5000/5000 [==============================] - 3s 628us/step - loss: 54.2176 - dense_23_loss: 16.1803 - dense_24_loss: 38.0373 - val_loss: 52.2525 - val_dense_23_loss: 15.1291 - val_dense_24_loss: 37.1233
Epoch 18/30
5000/5000 [==============================] - 4s 744us/step - loss: 53.8684 - dense_23_loss: 15.9648 - dense_24_loss: 37.9036 - val_loss: 51.7984 - val_dense_23_loss: 14.8564 - val_dense_24_loss: 36.9420
Epoch 19/30
5000/5000 [==============================] - 3s 600us/step - loss: 53.5824 - dense_23_loss: 15.7955 - dense_24_loss: 37.7869 - val_loss: 51.4563 - val_dense_23_loss: 14.6374 - val_dense_24_loss: 36.8189
Epoch 20/30
5000/5000 [==============================] - 3s 534us/step - loss: 53.3213 - dense_23_loss: 15.6395 - dense_24_loss: 37.6818 - val_loss: 51.2304 - val_dense_23_loss: 14.5035 - val_dense_24_loss: 36.7269
Epoch 21/30
5000/5000 [==============================] - 4s 701us/step - loss: 53.0680 - dense_23_loss: 15.4855 - dense_24_loss: 37.5825 - val_loss: 50.9055 - val_dense_23_loss: 14.3052 - val_dense_24_loss: 36.6004
Epoch 22/30
5000/5000 [==============================] - 4s 735us/step - loss: 52.8112 - dense_23_loss: 15.3356 - dense_24_loss: 37.4755 - val_loss: 50.6851 - val_dense_23_loss: 14.1706 - val_dense_24_loss: 36.5145
Epoch 23/30
5000/5000 [==============================] - 4s 733us/step - loss: 52.6116 - dense_23_loss: 15.2092 - dense_24_loss: 37.4025 - val_loss: 50.3594 - val_dense_23_loss: 13.9758 - val_dense_24_loss: 36.3836
Epoch 24/30
5000/5000 [==============================] - 4s 849us/step - loss: 52.4208 - dense_23_loss: 15.0949 - dense_24_loss: 37.3259 - val_loss: 50.1555 - val_dense_23_loss: 13.8413 - val_dense_24_loss: 36.3142
Epoch 25/30
5000/5000 [==============================] - 4s 732us/step - loss: 52.2180 - dense_23_loss: 14.9752 - dense_24_loss: 37.2428 - val_loss: 49.9194 - val_dense_23_loss: 13.6918 - val_dense_24_loss: 36.2276
Epoch 26/30
5000/5000 [==============================] - 3s 604us/step - loss: 52.0623 - dense_23_loss: 14.8808 - dense_24_loss: 37.1815 - val_loss: 49.7312 - val_dense_23_loss: 13.5534 - val_dense_24_loss: 36.1778
Epoch 27/30
5000/5000 [==============================] - 3s 507us/step - loss: 51.8958 - dense_23_loss: 14.7736 - dense_24_loss: 37.1222 - val_loss: 49.5618 - val_dense_23_loss: 13.4641 - val_dense_24_loss: 36.0977
Epoch 28/30
5000/5000 [==============================] - 3s 566us/step - loss: 51.7471 - dense_23_loss: 14.6879 - dense_24_loss: 37.0592 - val_loss: 49.3073 - val_dense_23_loss: 13.3000 - val_dense_24_loss: 36.0073
Epoch 29/30
5000/5000 [==============================] - 3s 595us/step - loss: 51.5263 - dense_23_loss: 14.5482 - dense_24_loss: 36.9781 - val_loss: 49.1972 - val_dense_23_loss: 13.2198 - val_dense_24_loss: 35.9775
Epoch 30/30
5000/5000 [==============================] - 3s 551us/step - loss: 51.4380 - dense_23_loss: 14.4985 - dense_24_loss: 36.9395 - val_loss: 48.9715 - val_dense_23_loss: 13.0922 - val_dense_24_loss: 35.8793
<keras.callbacks.callbacks.History at 0x226a912c6d8>
Network 2: Different objective functions for the two outputs
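# Callback to increase the weight (beta) of the KL regularization as training progresses.
# beta is held in a backend variable so that the per-epoch updates reach the compiled loss;
# a plain Python float is read only once, when the loss graph is built.
class increaseBeta(Callback):
    def __init__(self):
        super(increaseBeta, self).__init__()
        self.global_beta = K.variable(0.0)

    def on_train_begin(self, logs={}):
        K.set_value(self.global_beta, 0.0)

    def on_epoch_end(self, epoch, logs={}):
        K.set_value(self.global_beta, K.get_value(self.global_beta) + 0.01)

updateBeta = increaseBeta()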
# Function to L2-normalize the inputs
def l2normalize(args):
    _x = args
    return K.l2_normalize(_x, axis=-1)
# Function to sample from the latent space (reparameterization trick)
def sampling(args):
    _mean, _log_var = args
    epsilon = K.random_normal(shape=(K.shape(_mean)[0], latent_dim), mean=0., stddev=epsilon_std)
    return _mean + K.exp(_log_var / 2) * epsilon
# encoder network
x=Input(batch_shape=(batch_size,original_dim))
norm_x = Lambda(l2normalize, output_shape=(original_dim,))(x)
norm_x = Dropout(rate = 0.5)(norm_x)
h=Dense(intermediate_dim, activation='relu')(norm_x)
z_mean=Dense(latent_dim)(h)
z_log_var=Dense(latent_dim)(h)
z= Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
# decoder network
h_decoder=Dense(intermediate_dim, activation='relu')
x_bar=Dense(original_dim, activation='sigmoid')
x_prob=Dense(original_dim, activation='softmax')
h_decoded = h_decoder(z)
# We have two outputs: one reconstructs the given input, the other gives a probability distribution over the diagnosis codes
x_decoded = x_bar(h_decoded)
x_probability = x_prob(h_decoded)
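def vae_loss1(x, x_bar):
    # Same objective as Network 1: binary cross-entropy plus beta-weighted KL
    reconst_loss = K.sum(losses.binary_crossentropy(x, x_bar), axis=-1)
    kl_loss = K.sum(0.5 * (K.exp(z_log_var) - z_log_var + K.square(z_mean) - 1), axis=-1)
    return reconst_loss + updateBeta.global_beta * kl_loss

def vae_loss2(x, x_bar):
    # -sum(p * x): minus the softmax probability mass placed on the patient's codes (between -1 and 0).
    # The multinomial log-likelihood of Liang et al. would be -K.sum(x * K.log(x_bar + K.epsilon()), axis=-1).
    neg_ll = -K.sum(x_bar * x, axis=-1)
    kl_loss = K.sum(0.5 * (K.exp(z_log_var) - z_log_var + K.square(z_mean) - 1), axis=-1)
    return neg_ll + updateBeta.global_beta * kl_loss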
# build and compile model
vae2 = Model(x, [x_decoded, x_probability])
vae2.compile(optimizer='adam', loss=[vae_loss1, vae_loss2], loss_weights=[0.5, 0.5])
# weightsPath = "./weights/weights_vae2.hdf5"
# checkpointer = ModelCheckpoint(filepath=weightsPath, verbose=1, save_best_only=True)
# reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)
# vae2.fit(x = x_train,y = [x_train, x_train], batch_size = batch_size, epochs=30,\
# validation_data=(x_val, [x_val, x_val]), callbacks=[checkpointer, reduce_lr, plot_losses, updateBeta])
vae2.fit(x = x_train,y = [x_train, x_train], batch_size = batch_size, epochs=30,\
validation_data=(x_val, [x_val, x_val]), )
Train on 5000 samples, validate on 500 samples
Epoch 1/30
5000/5000 [==============================] - 4s 729us/step - loss: 16.9416 - dense_29_loss: 34.2665 - dense_30_loss: -0.3832 - val_loss: 10.8511 - val_dense_29_loss: 22.3960 - val_dense_30_loss: -0.6938
Epoch 2/30
5000/5000 [==============================] - 3s 596us/step - loss: 10.6567 - dense_29_loss: 22.0604 - dense_30_loss: -0.7470 - val_loss: 10.5063 - val_dense_29_loss: 21.7948 - val_dense_30_loss: -0.7821
Epoch 3/30
5000/5000 [==============================] - 3s 587us/step - loss: 10.4868 - dense_29_loss: 21.7343 - dense_30_loss: -0.7607 - val_loss: 10.3348 - val_dense_29_loss: 21.4568 - val_dense_30_loss: -0.7872
Epoch 4/30
5000/5000 [==============================] - 3s 588us/step - loss: 10.3471 - dense_29_loss: 21.4561 - dense_30_loss: -0.7619 - val_loss: 10.2142 - val_dense_29_loss: 21.2168 - val_dense_30_loss: -0.7883
Epoch 5/30
5000/5000 [==============================] - 3s 583us/step - loss: 10.1537 - dense_29_loss: 21.0695 - dense_30_loss: -0.7622 - val_loss: 9.8815 - val_dense_29_loss: 20.5522 - val_dense_30_loss: -0.7891
Epoch 6/30
5000/5000 [==============================] - 4s 700us/step - loss: 9.8399 - dense_29_loss: 20.4424 - dense_30_loss: -0.7627 - val_loss: 9.6098 - val_dense_29_loss: 20.0090 - val_dense_30_loss: -0.7894
Epoch 7/30
5000/5000 [==============================] - 3s 661us/step - loss: 9.6009 - dense_29_loss: 19.9645 - dense_30_loss: -0.7627 - val_loss: 9.3354 - val_dense_29_loss: 19.4604 - val_dense_30_loss: -0.7896
Epoch 8/30
5000/5000 [==============================] - 4s 732us/step - loss: 9.3988 - dense_29_loss: 19.5603 - dense_30_loss: -0.7628 - val_loss: 9.1374 - val_dense_29_loss: 19.0644 - val_dense_30_loss: -0.7897
Epoch 9/30
5000/5000 [==============================] - 3s 698us/step - loss: 9.1813 - dense_29_loss: 19.1254 - dense_30_loss: -0.7629 - val_loss: 8.9246 - val_dense_29_loss: 18.6389 - val_dense_30_loss: -0.7896
Epoch 10/30
5000/5000 [==============================] - 4s 761us/step - loss: 8.9580 - dense_29_loss: 18.6789 - dense_30_loss: -0.7629 - val_loss: 8.6732 - val_dense_29_loss: 18.1362 - val_dense_30_loss: -0.7898
Epoch 11/30
5000/5000 [==============================] - 3s 695us/step - loss: 8.7604 - dense_29_loss: 18.2836 - dense_30_loss: -0.7629 - val_loss: 8.3996 - val_dense_29_loss: 17.5891 - val_dense_30_loss: -0.7899
Epoch 12/30
5000/5000 [==============================] - 3s 628us/step - loss: 8.5640 - dense_29_loss: 17.8911 - dense_30_loss: -0.7630 - val_loss: 8.1917 - val_dense_29_loss: 17.1734 - val_dense_30_loss: -0.7899
Epoch 13/30
5000/5000 [==============================] - 3s 616us/step - loss: 8.3888 - dense_29_loss: 17.5405 - dense_30_loss: -0.7630 - val_loss: 7.9692 - val_dense_29_loss: 16.7284 - val_dense_30_loss: -0.7900
Epoch 14/30
5000/5000 [==============================] - 3s 638us/step - loss: 8.2512 - dense_29_loss: 17.2653 - dense_30_loss: -0.7630 - val_loss: 7.8081 - val_dense_29_loss: 16.4062 - val_dense_30_loss: -0.7900
Epoch 15/30
5000/5000 [==============================] - 4s 777us/step - loss: 8.1205 - dense_29_loss: 17.0040 - dense_30_loss: -0.7630 - val_loss: 7.6419 - val_dense_29_loss: 16.0737 - val_dense_30_loss: -0.7900
Epoch 16/30
5000/5000 [==============================] - 5s 972us/step - loss: 7.9886 - dense_29_loss: 16.7401 - dense_30_loss: -0.7630 - val_loss: 7.5225 - val_dense_29_loss: 15.8349 - val_dense_30_loss: -0.7900
Epoch 17/30
5000/5000 [==============================] - 4s 724us/step - loss: 7.8793 - dense_29_loss: 16.5216 - dense_30_loss: -0.7630 - val_loss: 7.4022 - val_dense_29_loss: 15.5944 - val_dense_30_loss: -0.7900
Epoch 18/30
5000/5000 [==============================] - 4s 753us/step - loss: 7.7752 - dense_29_loss: 16.3134 - dense_30_loss: -0.7630 - val_loss: 7.2394 - val_dense_29_loss: 15.2688 - val_dense_30_loss: -0.7900
Epoch 19/30
5000/5000 [==============================] - 4s 841us/step - loss: 7.6858 - dense_29_loss: 16.1345 - dense_30_loss: -0.7630 - val_loss: 7.1225 - val_dense_29_loss: 15.0350 - val_dense_30_loss: -0.7900
Epoch 20/30
5000/5000 [==============================] - 4s 801us/step - loss: 7.5966 - dense_29_loss: 15.9562 - dense_30_loss: -0.7630 - val_loss: 6.9988 - val_dense_29_loss: 14.7876 - val_dense_30_loss: -0.7900
Epoch 21/30
5000/5000 [==============================] - 4s 744us/step - loss: 7.5157 - dense_29_loss: 15.7943 - dense_30_loss: -0.7630 - val_loss: 6.9055 - val_dense_29_loss: 14.6010 - val_dense_30_loss: -0.7900
Epoch 22/30
5000/5000 [==============================] - 4s 847us/step - loss: 7.4233 - dense_29_loss: 15.6096 - dense_30_loss: -0.7630 - val_loss: 6.8266 - val_dense_29_loss: 14.4433 - val_dense_30_loss: -0.7900
Epoch 23/30
5000/5000 [==============================] - 3s 668us/step - loss: 7.3547 - dense_29_loss: 15.4724 - dense_30_loss: -0.7630 - val_loss: 6.7169 - val_dense_29_loss: 14.2239 - val_dense_30_loss: -0.7900
Epoch 24/30
5000/5000 [==============================] - 3s 653us/step - loss: 7.2897 - dense_29_loss: 15.3423 - dense_30_loss: -0.7630 - val_loss: 6.6318 - val_dense_29_loss: 14.0535 - val_dense_30_loss: -0.7900
Epoch 25/30
5000/5000 [==============================] - 3s 693us/step - loss: 7.2195 - dense_29_loss: 15.2020 - dense_30_loss: -0.7630 - val_loss: 6.5772 - val_dense_29_loss: 13.9444 - val_dense_30_loss: -0.7900
Epoch 26/30
5000/5000 [==============================] - 4s 846us/step - loss: 7.1611 - dense_29_loss: 15.0852 - dense_30_loss: -0.7630 - val_loss: 6.5330 - val_dense_29_loss: 13.8560 - val_dense_30_loss: -0.7900
Epoch 27/30
5000/5000 [==============================] - 4s 817us/step - loss: 7.1067 - dense_29_loss: 14.9764 - dense_30_loss: -0.7630 - val_loss: 6.4391 - val_dense_29_loss: 13.6681 - val_dense_30_loss: -0.7900
Epoch 28/30
5000/5000 [==============================] - 4s 779us/step - loss: 7.0528 - dense_29_loss: 14.8687 - dense_30_loss: -0.7630 - val_loss: 6.3591 - val_dense_29_loss: 13.5081 - val_dense_30_loss: -0.7900
Epoch 29/30
5000/5000 [==============================] - 4s 728us/step - loss: 6.9904 - dense_29_loss: 14.7439 - dense_30_loss: -0.7630 - val_loss: 6.2726 - val_dense_29_loss: 13.3353 - val_dense_30_loss: -0.7900
Epoch 30/30
5000/5000 [==============================] - 5s 901us/step - loss: 6.9431 - dense_29_loss: 14.6491 - dense_30_loss: -0.7630 - val_loss: 6.2600 - val_dense_29_loss: 13.3100 - val_dense_30_loss: -0.7900
<keras.callbacks.callbacks.History at 0x226ad466748>
with open('test.data', 'rb') as f:
    x_test = np.load(f)
print("number of testing users: ", x_test.shape[0])
number of testing users: 647
x_test = x_test[:600]
x_test.shape
(600, 500)
# x_test[0]
Calculating Recall
The trained networks are tested as follows. For each patient:
- Choose one diagnosis at random from the M diagnoses for which the patient has the value 1 (i.e. the patient has received that diagnosis).
- Set that diagnosis to 0.
- Pass the masked vector through the network to obtain the probability distribution over the diagnosis codes.
- Sort the diagnosis codes by their probabilities.
- The network was given M-1 diagnoses, and these are themselves likely to score highly. Recall@k is therefore the fraction of patients for whom the held-out diagnosis appears among the top (M-1)+k codes. Patients with fewer than five remaining diagnoses are skipped.
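The hit test above can be sketched in plain Python for one patient at a time. The helpers `heldout_hit` and `recall_at_k` are hypothetical, and the scores are made up. Like the notebook's calc_heldout_recall, the sketch counts a hit when the held-out code appears in the top (M - 1) + k scores. For brevity, it omits the filter that skips patients with fewer than five remaining codes.

```python
def heldout_hit(scores, held_out, n_observed, k):
    """True if `held_out` is among the top (n_observed - 1) + k codes by score.

    scores     -- one score per diagnosis code for the masked input
    held_out   -- index of the code that was set to 0 before prediction
    n_observed -- M, the number of codes the patient had before masking
    """
    window = (n_observed - 1) + k
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return held_out in ranked[:window]


def recall_at_k(patients, k):
    """Fraction of (scores, held_out, n_observed) triples that are hits."""
    if not patients:
        return 0.0
    hits = sum(heldout_hit(s, h, m, k) for s, h, m in patients)
    return hits / len(patients)


# Patient with M = 3 codes {0, 1, 4}; code 4 was held out
scores = [0.50, 0.40, 0.05, 0.20, 0.30]
scores_b = [0.50, 0.40, 0.05, 0.35, 0.30]
print(heldout_hit(scores, held_out=4, n_observed=3, k=1))    # True
print(heldout_hit(scores_b, held_out=4, n_observed=3, k=1))  # False
print(heldout_hit(scores_b, held_out=4, n_observed=3, k=2))  # True
print(recall_at_k([(scores, 4, 3), (scores_b, 4, 3)], k=1))  # 0.5
```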
x_test_hold_new = np.copy(x_test)
# Hold out one observed code per patient, drawn from the first 473 columns only
hold_out_ind_new = [np.random.choice(np.nonzero(i)[0]) for i in x_test[:, :473]]
for i in range(x_test.shape[0]):
    x_test_hold_new[i][hold_out_ind_new[i]] = 0
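def calc_heldout_recall_new(x_test, x_rec, k):
    count = 0.0
    tot = 0.0
    # Code indices sorted by increasing score, so the highest scores are at the end
    x_rank = np.argsort(x_rec)
    for i in range(x_rank.shape[0]):
        sm = int(np.sum(x_test[i])) - 1  # number of codes left in the input (M - 1)
        if sm < 5:
            continue  # skip patients with fewer than 5 remaining codes
        tot += 1
        if hold_out_ind_new[i] in x_rank[i][-(k + sm):]:
            count += 1.0
    return count / tot if tot else 0.0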
x_rec, x_prob = vae.predict(x_test_hold_new, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall_new(x_test, x_prob[:, :473], k))
x_rec, x_prob = vae2.predict(x_test_hold_new, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall_new(x_test, x_prob[:, :473], k))
x_test_hold = np.copy(x_test)
hold_out_ind = [np.random.choice(np.nonzero(i)[0]) for i in x_test]
for i in range(x_test.shape[0]):
    x_test_hold[i][hold_out_ind[i]] = 0
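def calc_heldout_recall(x_test, x_rec, k):
    count = 0.0
    tot = 0.0
    # Code indices sorted by increasing score, so the highest scores are at the end
    x_rank = np.argsort(x_rec)
    for i in range(x_rank.shape[0]):
        sm = int(np.sum(x_test[i])) - 1  # number of codes left in the input (M - 1)
        if sm < 5:
            continue  # skip patients with fewer than 5 remaining codes
        tot += 1
        if hold_out_ind[i] in x_rank[i][-(k + sm):]:
            count += 1.0
    return count / tot if tot else 0.0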
x_rec, x_prob = vae.predict(x_test_hold, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall(x_test, x_prob, k))
x_rec, x_prob = vae2.predict(x_test_hold, batch_size=batch_size)
for k in [1, 2, 3, 4, 5, 10, 15]:
    print(calc_heldout_recall(x_test, x_prob, k))
Impact of different ways of calculating the objective functions
The recall@k values for k = 1, 2, 3, 4, 5, 10, 15 are substantial, considering that the network has to pick the held-out diagnosis out of hundreds of candidate codes (500 in the data used above).
An interesting observation is that the second way of calculating the objective gives better recall than the first for smaller values of k. For larger values of k, the first approach does better.