Notations
- $x$: image
- $z$: latent variable
- $y$: label (omitted to lighten notation)
- $p(x|z)$: decoder

Encoder
- $q(z|x)$: encoder

Decoder
- $\hat{p}(z)$: prior encoder (by variational inference)

PriorEncoder

Model structure
class CVAE(nn.Module):
    """Conditional VAE: encodes (x, y) into a latent z, decodes (z, y) back to x.

    Components (Encoder / Decoder / PriorEncoder are defined elsewhere in
    this project; constructor args elided with `...` in the original):
      - encoder:      posterior q(z|x, y) -> (mu, sigma)
      - priorEncoder: learned conditional prior p(z|y) -> (prior_mu, prior_sigma)
      - decoder:      likelihood p(x|z, y) -> reconstructed image
    """

    def __init__(self, config):
        super(CVAE, self).__init__()
        self.encoder = Encoder(...)
        self.decoder = Decoder(...)
        self.priorEncoder = PriorEncoder(...)

    def forward(self, x, y):
        """Training pass.

        Returns the reconstruction together with the posterior and prior
        Gaussian parameters needed by the KL term of the loss.
        """
        x = x.reshape((-1, 784))  # flatten 28x28 MNIST images
        mu, sigma = self.encoder(x, y)
        prior_mu, prior_sigma = self.priorEncoder(y)
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I),
        # keeps the sampling step differentiable w.r.t. (mu, sigma).
        z = torch.randn_like(mu)
        z = z * sigma + mu
        reconstructed_x = self.decoder(z, y)
        reconstructed_x = reconstructed_x.reshape((-1, 28, 28))
        return reconstructed_x, mu, sigma, prior_mu, prior_sigma

    def infer(self, y):
        """Generation pass: sample z from the learned conditional prior
        p(z|y) and decode. No input image is required.
        """
        prior_mu, prior_sigma = self.priorEncoder(y)
        z = torch.randn_like(prior_mu)
        z = z * prior_sigma + prior_mu
        reconstructed_x = self.decoder(z, y)
        # Fix: match forward()'s output shape. The original returned the
        # flat decoder output here, inconsistent with forward(), which
        # reshapes to (N, 28, 28).
        reconstructed_x = reconstructed_x.reshape((-1, 28, 28))
        return reconstructed_x
#
class Loss(nn.Module):
    """CVAE objective: reconstruction MSE plus weighted KL divergence.

    loss = MSE(x, x_hat)
         + w * KL( N(mu, sigma^2) || N(prior_mu, prior_sigma^2) )

    The KL term is the closed form between two diagonal Gaussians,
    summed over latent dimensions and averaged over the batch.
    """

    def __init__(self):
        super(Loss, self).__init__()
        # Element-wise mean squared error for the reconstruction term.
        self.loss_fn = nn.MSELoss(reduction='mean')
        # Small weight keeps the KL term from overwhelming reconstruction.
        self.kld_loss_weight = 1e-5

    def forward(self, x, reconstructed_x, mu, sigma, prior_mu, prior_sigma):
        recon_term = self.loss_fn(x, reconstructed_x)
        # Per-dimension closed-form KL between the posterior N(mu, sigma^2)
        # and the conditional prior N(prior_mu, prior_sigma^2).
        log_ratio = torch.log(prior_sigma / sigma)
        quad = (sigma ** 2 + (mu - prior_mu) ** 2) / (2 * prior_sigma ** 2)
        kld_per_dim = log_ratio + quad - 0.5
        # Sum over latent dimensions, then average over the batch.
        kld_term = kld_per_dim.sum() / x.shape[0]
        return recon_term + self.kld_loss_weight * kld_term
#
def train(model, criterion, optimizer, data_loader, config):
train_task_time_str = time_str()
for epoch in range(config.num_epoch):
loss_seq = []
for step, (x,y)