Implementing Autoencoders in MATLAB (5): Helper Functions for Generating Images with a Variational Autoencoder (VAE)

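This post translates the page Train Variational Autoencoder (VAE) to Generate Images, which trains a variational autoencoder on the MNIST handwritten digits and uses it to generate similar images. This post covers the helper functions from that page. For the main part, see Implementing Autoencoders in MATLAB (4): Generating Images with a Variational Autoencoder.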

processImagesMNIST

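First come two functions for processing the MNIST data set. They handle the images and the labels respectively, so that both match the input requirements of the network.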

function X = processImagesMNIST(filename)
% The MNIST processing functions extract the data from the downloaded IDX
% files into MATLAB arrays. The processImagesMNIST function performs these
% operations: Check if the file can be opened correctly. Obtain the magic
% number by reading the first four bytes. The magic number is 2051 for
% image data, and 2049 for label data. Read the next 3 sets of 4 bytes,
% which return the number of images, the number of rows, and the number of
% columns. Read the image data. Reshape the array and swap the first two
% dimensions, because the IDX file stores pixels row by row while MATLAB
% arrays are column-major. Scale the pixel values to the range [0,1] by
% dividing them by 255, and convert the 3-D array to a 4-D dlarray object.
% Close the file.

[fileID,errmsg] = fopen(filename,'r','b');
if fileID < 0
    error(errmsg);
end

magicNum = fread(fileID,1,'int32',0,'b');
if magicNum == 2051
    fprintf('\nRead MNIST image data...\n')
end

numImages = fread(fileID,1,'int32',0,'b');
fprintf('Number of images in the dataset: %6d ...\n',numImages);
numRows = fread(fileID,1,'int32',0,'b');
numCols = fread(fileID,1,'int32',0,'b');

X = fread(fileID,inf,'unsigned char');

X = reshape(X,numCols,numRows,numImages);
X = permute(X,[2 1 3]);
X = X./255;
X = reshape(X, [28,28,1,size(X,3)]);
X = dlarray(X, 'SSCB');

fclose(fileID);
end
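
The reshape and permute steps above can be checked on a toy example. This is a minimal sketch, not part of the original page: the 2-by-3 "image" and the variable raw are made up for illustration, and only base MATLAB functions are used.

```matlab
% Hypothetical toy data: one image with 2 rows and 3 columns,
% stored row by row as in an IDX file: [1 2 3; 4 5 6].
raw = (1:6)';
numRows = 2; numCols = 3; numImages = 1;

% reshape fills column by column, so each image row lands in a column of A.
A = reshape(raw, numCols, numRows, numImages);   % 3-by-2: [1 4; 2 5; 3 6]

% Swapping the first two dimensions restores the intended orientation.
A = permute(A, [2 1 3]);                         % 2-by-3: [1 2 3; 4 5 6]
disp(A)
```

Without the permute call, every digit would come out transposed.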

processLabelsMNIST

This function processes the labels so that they match the input requirements of the network.

function Y = processLabelsMNIST(filename)
% The processLabelsMNIST function operates similarly to the
% processImagesMNIST function. After opening the file and reading the magic
% number, it reads the labels and returns a categorical array containing
% their values.

[fileID,errmsg] = fopen(filename,'r','b');

if fileID < 0
    error(errmsg);
end

magicNum = fread(fileID,1,'int32',0,'b');
if magicNum == 2049
    fprintf('\nRead MNIST label data...\n')
end

numItems = fread(fileID,1,'int32',0,'b');
fprintf('Number of labels in the dataset: %6d ...\n',numItems);

Y = fread(fileID,inf,'unsigned char');

Y = categorical(Y);

fclose(fileID);
end

Model Gradients Function

The modelGradients function takes the encoder and decoder dlnetwork objects and a mini-batch of input data X, and returns the gradients of the loss with respect to the learnable parameters in the networks. The function performs three operations:

  • Obtain the encodings by calling the sampling function, which passes the mini-batch of images through the encoder network.
  • Obtain the loss by passing the encodings through the decoder network and calling the ELBOloss function.
  • Compute the gradients of the loss with respect to the learnable parameters of both networks by calling the dlgradient function.

Note that dlgradient returns the gradients in the order in which the parameters are passed (decoder first, then encoder), while modelGradients itself returns the encoder gradients first.

function [infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x)
[z, zMean, zLogvar] = sampling(encoderNet, x);
xPred = sigmoid(forward(decoderNet, z));
loss = ELBOloss(x, xPred, zMean, zLogvar);
[genGrad, infGrad] = dlgradient(loss, decoderNet.Learnables, ...
    encoderNet.Learnables);
end
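
Because modelGradients calls dlgradient, it has to be evaluated through dlfeval, which enables the tracing that automatic differentiation relies on. The sketch below shows the same pattern on a toy loss. The functions demoDlfeval and quadLoss are hypothetical names introduced for illustration; only dlarray, dlfeval, dlgradient and extractdata come from Deep Learning Toolbox.

```matlab
function grad = demoDlfeval()
% Hypothetical demo, not part of the original page.
w = dlarray([1; 2]);                 % parameters to differentiate against
[~, grad] = dlfeval(@quadLoss, w);   % dlfeval turns on tracing for quadLoss
grad = extractdata(grad);            % analytically 2*w, i.e. [2; 4]
end

function [loss, grad] = quadLoss(w)
loss = sum(w.^2);                    % scalar loss, as dlgradient requires
grad = dlgradient(loss, w);          % only valid inside a dlfeval call
end
```

In a training loop, the call would presumably take the form dlfeval(@modelGradients, encoderNet, decoderNet, XBatch), where XBatch stands for the current mini-batch.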

Sampling and Loss Functions

The sampling function obtains encodings from input images. It first passes a mini-batch of images through the encoder network and splits the output, of size (2*latentDim)-by-miniBatchSize, into a matrix of means and a matrix of log-variances, each of size latentDim-by-miniBatchSize. It then uses these matrices to implement the reparameterization trick and compute the encoding. Finally, it converts the encoding to a dlarray object in SSCB format.

function [zSampled, zMean, zLogvar] = sampling(encoderNet, x)
compressed = forward(encoderNet, x);
d = size(compressed,1)/2;
zMean = compressed(1:d,:);
zLogvar = compressed(1+d:end,:);

sz = size(zMean);
epsilon = randn(sz);
sigma = exp(.5 * zLogvar);
z = epsilon .* sigma + zMean;
z = reshape(z, [1,1,sz]);
zSampled = dlarray(z, 'SSCB');
end
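
For reference, the reparameterization step in the code above can be written as follows. The symbols are my own notation rather than the original page's: \mu stands for zMean, \log\sigma^{2} for zLogvar, and \varepsilon for epsilon.

```latex
\sigma = \exp\!\left(\tfrac{1}{2}\log\sigma^{2}\right), \qquad
\varepsilon \sim \mathcal{N}(0, I), \qquad
z = \mu + \sigma \odot \varepsilon
```

Because the randomness sits entirely in \varepsilon, z remains a differentiable function of \mu and \log\sigma^{2}, which is what allows dlgradient to propagate gradients back into the encoder.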

ELBOloss

The ELBOloss function takes the input x, the reconstruction xPred, and the mean and log-variance encodings returned by the sampling function, and uses them to compute the ELBO loss.

function elbo = ELBOloss(x, xPred, zMean, zLogvar)
squares = 0.5*(xPred-x).^2;
reconstructionLoss  = sum(squares, [1,2,3]);

KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1);

elbo = mean(reconstructionLoss + KL);
end
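
Read term by term, the code above appears to compute the following quantity. The notation is an assumption of mine, not the original page's: N is the mini-batch size, d the latent dimension, \hat{x} is xPred, \mu is zMean, and \log\sigma^{2} is zLogvar (so \sigma^{2} corresponds to exp(zLogvar)).

```latex
\mathcal{L} = \frac{1}{N}\sum_{n=1}^{N}\left[
  \frac{1}{2}\sum_{i}\left(\hat{x}_{n,i} - x_{n,i}\right)^{2}
  \;-\; \frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_{n,j}^{2} - \mu_{n,j}^{2} - \sigma_{n,j}^{2}\right)
\right]
```

The first sum is the squared reconstruction error over all pixels of one image. The second is the KL divergence between the encoder distribution and a standard normal prior. As a quick sanity check, with \mu = 0 and \log\sigma^{2} = 0 each KL term is 1 + 0 - 0 - 1 = 0, so the penalty vanishes exactly when the encoder output matches the prior.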

Visualization Functions

The visualizeReconstruction function randomly chooses two images for each digit of the MNIST data set, passes them through the VAE, and plots each reconstruction side by side with the original input. Note that to plot the information contained in a dlarray object, you first need to extract it using the extractdata and gather functions.

function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet)
f = figure;
figure(f)
title("Example ground truth image vs. reconstructed image")
for i = 1:2
    for c=0:9
        idx = iRandomIdxOfClass(YTest,c);
        X = XTest(:,:,:,idx);

        [z, ~, ~] = sampling(encoderNet, X);
        XPred = sigmoid(forward(decoderNet, z));
        
        X = gather(extractdata(X));
        XPred = gather(extractdata(XPred));

        comparison = [X, ones(size(X,1),1), XPred];
        subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]),
    end
end
end

function idx = iRandomIdxOfClass(T,c)
idx = T == categorical(c);
idx = find(idx);
idx = idx(randi(numel(idx),1));
end
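
The extraction step mentioned above can be tried in isolation. This is a minimal sketch with a random stand-in image (the variables X, A and B are illustrative and not from the original page), assuming Deep Learning Toolbox is available for dlarray.

```matlab
% Hypothetical stand-in for a single 28-by-28 test image.
X = dlarray(rand(28,28,1,1), 'SSCB');

% extractdata returns the underlying numeric array; gather copies it to
% host memory if it lives on a GPU and leaves CPU data unchanged.
A = gather(extractdata(X));

% Same side-by-side layout as in visualizeReconstruction:
% image, one separator column of ones, image.
B = [A, ones(size(A,1),1), A];
disp(size(B))    % 28 57
```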

visualizeLatentSpace

The visualizeLatentSpace function visualizes the latent space defined by the mean and log-variance matrices that form the output of the encoder network, and locates the clusters formed by the latent space representations of each digit.

The function starts by extracting the mean and log-variance matrices from the dlarray objects. Because a matrix with channel/batch dimensions (C and B) cannot be transposed, the function calls stripdims before transposing the matrices. It then carries out a principal component analysis (PCA) on both matrices. To visualize the latent space in two dimensions, the function keeps the first two principal components and plots them against each other. Finally, the function colors the points by digit class so that you can observe the clusters.

function visualizeLatentSpace(XTest, YTest, encoderNet)
[~, zMean, zLogvar] = sampling(encoderNet, XTest);

zMean = stripdims(zMean)';
zMean = gather(extractdata(zMean));

zLogvar = stripdims(zLogvar)';
zLogvar = gather(extractdata(zLogvar));

[~,scoreMean] = pca(zMean);
[~,scoreLogvar] = pca(zLogvar);

c = parula(10);
f1 = figure;
figure(f1)
title("Latent space")

ah = subplot(1,2,1);
scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
axis equal
xlabel("Z_m_u(2)")
ylabel("Z_m_u(1)")
cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);

ah = subplot(1,2,2);
scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
xlabel("Z_v_a_r(2)")
ylabel("Z_v_a_r(1)")
cb = colorbar;  cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);
axis equal
end
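
The stripdims, transpose and pca sequence can be sketched on random data. The sizes below (a latent dimension of 20 and 200 observations) are assumptions chosen for illustration, and pca requires Statistics and Machine Learning Toolbox.

```matlab
rng(0);                                  % reproducible stand-in data
Z = dlarray(randn(20,200), 'CB');        % latentDim-by-batch, like zMean

Zt = stripdims(Z)';                      % remove the C/B labels, then transpose
Zt = gather(extractdata(Zt));            % 200-by-20 numeric matrix

[~, score] = pca(Zt);                    % rows of score are observations
scatter(score(:,2), score(:,1));         % first two principal components
```

pca expects observations in rows, which is why the latentDim-by-batch matrix has to be transposed first.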

generate

The generate function tests the generative capabilities of the VAE. It initializes a dlarray object containing 25 randomly generated encodings, passes them through the decoder network, and plots the outputs.

function generate(decoderNet, latentDim)
randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB');
generatedImage = sigmoid(predict(decoderNet, randomNoise));
generatedImage = extractdata(generatedImage);

f3 = figure;
figure(f3)
imshow(imtile(generatedImage, "ThumbnailSize", [100,100]))
title("Generated samples of digits")
drawnow
end