The code comes from the open-source GitHub repository: https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py#L122
The contiguous() method ensures that a tensor is stored contiguously in memory. After a shape transformation such as permute(), the storage order may no longer be contiguous, so .contiguous() is called to rearrange the underlying storage and guarantee that subsequent operations (such as view()) work correctly.
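A minimal sketch of the problem .contiguous() solves (the tensor shapes here are arbitrary, chosen only for illustration):

import torch

x = torch.randn(2, 3, 4)
y = x.permute(0, 2, 1)            # shape becomes [2, 4, 3], but only the strides change
print(y.is_contiguous())          # False
# y.view(2, 12)                   # would raise a RuntimeError on a non-contiguous tensor
z = y.contiguous().view(2, 12)    # copy into contiguous memory first, then reshape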
.embedding.weight.data.uniform_(a, b) fills the embedding weights in place with values sampled uniformly from the interval [a, b].
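A quick check of what this initializer does (K and D here are just the example values used in the code below):

import torch
from torch import nn

K, D = 10, 1024
embedding = nn.Embedding(K, D)
embedding.weight.data.uniform_(-1 / K, 1 / K)   # fill in place with samples from U(-0.1, 0.1)
print(embedding.weight.min().item() >= -1 / K)  # True
print(embedding.weight.max().item() <=  1 / K)  # True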
Next we define the vector-quantizer (VQ) class. It takes num_embeddings, the number of codewords in the codebook, and embedding_dim, the dimension of each codeword. The incoming latents are compared against the codebook entries using the L2 distance, and each latent vector is quantized to its nearest-neighbour codeword.
import torch
from torch import nn, Tensor
from torch.nn import functional as F

class VectorQuantizer(nn.Module):
    """
    Reference:
    [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
    """
    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 beta: float = 0.25):
        super(VectorQuantizer, self).__init__()
        self.K = num_embeddings  # e.g. 10
        self.D = embedding_dim   # e.g. 1024
        self.beta = beta

        # Codebook: K codewords of dimension D, e.g. [10 x 1024]
        self.embedding = nn.Embedding(self.K, self.D)
        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)

    def forward(self, latents: Tensor) -> Tensor:
        # [B x D x H x W] -> [B x H x W x D], e.g. [32 x 1024 x 3 x 3] -> [32 x 3 x 3 x 1024]
        latents = latents.permute(0, 2, 3, 1).contiguous()
        latents_shape = latents.shape
        flat_latents = latents.view(-1, self.D)  # [BHW x D], e.g. [288 x 1024]

        # Compute the L2 distance between latents and embedding weights:
        # ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * z . e, evaluated for every (latent, codeword) pair
        dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \
               torch.sum(self.embedding.weight ** 2, dim=1) - \
               2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHW x K], e.g. [288 x 10]

        # Get the encoding with the minimum distance, i.e. the index of the nearest codeword
        encoding_inds = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHW x 1], e.g. [288 x 1]

        # Convert to one-hot encodings
        device = latents.device
        # All-zero matrix with one row per latent vector and one column per codeword
        encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device)  # [288 x 10]
        # Set a 1 at the index of the nearest codeword in each row
        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [288 x 10]

        # Quantize the latents: multiplying by the one-hot matrix selects the nearest codeword
        # (entries multiplied by 0 vanish, the entry multiplied by 1 is kept)
        quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight)  # [BHW x D]
        quantized_latents = quantized_latents.view(latents_shape)  # [B x H x W x D]

        # Compute the VQ losses
        commitment_loss = F.mse_loss(quantized_latents.detach(), latents)
        embedding_loss = F.mse_loss(quantized_latents, latents.detach())
        vq_loss = commitment_loss * self.beta + embedding_loss

        # Straight-through estimator: gradients w.r.t. the quantized output are passed to the latents unchanged
        quantized_latents = latents + (quantized_latents - latents).detach()

        # [B x H x W x D] -> [B x D x H x W]
        return quantized_latents.permute(0, 3, 1, 2).contiguous(), vq_loss
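A minimal usage sketch, assuming the shapes from the comments above (a batch of 32 latent maps with 1024 channels on a 3x3 grid; the input tensor here is random and only for illustration):

vq = VectorQuantizer(num_embeddings=10, embedding_dim=1024)
latents = torch.randn(32, 1024, 3, 3)   # [B x D x H x W], as produced by the encoder
quantized, vq_loss = vq(latents)
print(quantized.shape)                  # torch.Size([32, 1024, 3, 3])
print(vq_loss.item())                   # scalar loss = embedding loss + beta * commitment loss

Because of the straight-through trick on the last lines of forward(), the gradient of any reconstruction loss computed on quantized flows back to latents (and thus to the encoder) as if the quantization step were the identity.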