GAN
Constructor
Header file
/**
* The implementation of the standard GAN module. Generative Adversarial
* Networks (GANs) are a class of artificial intelligence algorithms used
* in unsupervised machine learning, implemented by a system of two neural
* networks contesting with each other in a zero-sum game framework. This
* technique can generate photographs that look at least superficially
* authentic to human observers, having many realistic characteristics.
* GANs have been used in Text-to-Image Synthesis, Medical Drug Discovery,
* High Resolution Imagery Generation, Neural Machine Translation and so on.
*
* For more information, see the following paper:
*
* @code
* @article{Goodfellow14,
* author = {Ian J. Goodfellow, Jean Pouget-Abadie, Mehdi Mirza, Bing Xu,
* David Warde-Farley, Sherjil Ozair, Aaron Courville and
* Yoshua Bengio},
* title = {Generative Adversarial Nets},
* year = {2014},
* url = {http://arxiv.org/abs/1406.2661},
* eprint = {1406.2661},
* }
* @endcode
*
* @tparam Model The class type of Generator and Discriminator.
* @tparam InitializationRuleType Type of Initializer.
* @tparam Noise The noise function to use.
* @tparam PolicyType The GAN variant to be used (GAN, DCGAN, WGAN or WGANGP).
*/
template<
typename Model,
typename InitializationRuleType,
typename Noise,
typename PolicyType = StandardGAN
>
class GAN
{
public:
/**
* Constructor for GAN class.
*
* @param generator Generator network.
* @param discriminator Discriminator network.
* @param initializeRule Initialization rule to use for initializing
* parameters.
* @param noiseFunction Function to be used for generating noise.
* @param noiseDim Dimension of noise vector to be created.
* @param batchSize Batch size to be used for training.
* @param generatorUpdateStep Number of steps to train Discriminator
* before updating Generator.
* @param preTrainSize Number of pre-training steps of Discriminator.
* @param multiplier Ratio of the Discriminator's learning rate to the Generator's.
* @param clippingParameter Weight range for enforcing Lipschitz constraint.
* @param lambda Parameter for setting the gradient penalty.
*/
GAN(Model generator,
Model discriminator,
InitializationRuleType& initializeRule,
Noise& noiseFunction,
const size_t noiseDim,
const size_t batchSize,
const size_t generatorUpdateStep,
const size_t preTrainSize,
const double multiplier,
const double clippingParameter = 0.01,
const double lambda = 10.0);
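Before looking at the implementation, here is a minimal sketch of how this constructor is typically invoked. It assumes mlpack's FFN model with a SigmoidCrossEntropyError loss, GaussianInitialization, and math::Random; the layer sizes and hyperparameter values below are illustrative, not taken from this section:

// A minimal sketch, assuming mlpack's FFN, GaussianInitialization and
// math::Random; the network architecture here is illustrative only.
FFN<SigmoidCrossEntropyError<>> generator;
generator.Add<Linear<>>(10, 128);
generator.Add<ReLULayer<>>();
generator.Add<Linear<>>(128, 784);

FFN<SigmoidCrossEntropyError<>> discriminator;
discriminator.Add<Linear<>>(784, 128);
discriminator.Add<ReLULayer<>>();
discriminator.Add<Linear<>>(128, 1);

GaussianInitialization gaussian(0, 1);
// Noise function: each element of the noise vector is drawn from [-1, 1].
std::function<double()> noiseFunction = []() { return math::Random(-1, 1); };

// noiseDim = 10, batchSize = 32, generatorUpdateStep = 1,
// preTrainSize = 300, multiplier = 1.0; clippingParameter and lambda
// keep their default values.
GAN<FFN<SigmoidCrossEntropyError<>>, GaussianInitialization,
    std::function<double()>> gan(generator, discriminator, gaussian,
                                 noiseFunction, 10, 32, 1, 300, 1.0);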
Implementation
template<
typename Model,
typename InitializationRuleType,
typename Noise,
typename PolicyType
>
GAN<Model, InitializationRuleType, Noise, PolicyType>::GAN(
Model generator,
Model discriminator,
InitializationRuleType& initializeRule,
Noise& noiseFunction,
const size_t noiseDim,
const size_t batchSize,
const size_t generatorUpdateStep,
const size_t preTrainSize,
const double multiplier,
const double clippingParameter,
const double lambda):
generator(std::move(generator)),
discriminator(std::move(discriminator)),
initializeRule(initializeRule),
noiseFunction(noiseFunction),
noiseDim(noiseDim),
numFunctions(0),
batchSize(batchSize),
currentBatch(0),
generatorUpdateStep(generatorUpdateStep),
preTrainSize(preTrainSize),
multiplier(multiplier),
clippingParameter(clippingParameter),
lambda(lambda),
reset(false),
deterministic(false),
genWeights(0),
discWeights(0)
{
// Insert IdentityLayer for joining the Generator and Discriminator.
this->discriminator.network.insert(
this->discriminator.network.begin(),
new IdentityLayer<>());
}
In general, a model's network member is a std::vector of layer pointers, so the constructor body inserts an IdentityLayer at the front of the discriminator's network, directly joining the Generator's output to the Discriminator's input, as the sketch below shows.
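A tiny std::vector sketch of the effect (the string contents are purely illustrative):

// Illustrative only: insert() at begin() makes the new element the first
// entry and shifts the discriminator's existing layers one position back.
std::vector<std::string> network = {"Linear", "ReLU", "Linear"};
network.insert(network.begin(), "Identity");
// network is now: {"Identity", "Linear", "ReLU", "Linear"}.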
Now let's look at the implementation of IdentityLayer:
IdentityLayer
/**
* Standard Identity-Layer using the identity activation function.
*/
template <
class ActivationFunction = IdentityFunction,
typename InputDataType = arma::mat,
typename OutputDataType = arma::mat
>
using IdentityLayer = BaseLayer<
ActivationFunction, InputDataType, OutputDataType>;
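So IdentityLayer instantiates BaseLayer with IdentityFunction, whose activation is f(x) = x with derivative 1 everywhere. Below is a condensed sketch of the IdentityFunction interface that BaseLayer relies on, assuming mlpack's usual activation-function convention; mlpack's real class also provides scalar overloads:

// Condensed sketch of IdentityFunction: Fn() computes f(x) = x
// elementwise, Deriv() fills the derivative f'(x) = 1 everywhere.
class IdentityFunction
{
 public:
  template<typename InputVecType, typename OutputVecType>
  static void Fn(const InputVecType& x, OutputVecType& y)
  {
    y = x;
  }

  template<typename InputVecType, typename OutputVecType>
  static void Deriv(const InputVecType& y, OutputVecType& x)
  {
    x.ones(arma::size(y));
  }
};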
BaseLayer
/**
* Implementation of the base layer. The base layer works as a metaclass which
* attaches various functions to the embedding layer.
*
* A few convenience typedefs are given:
*
* - SigmoidLayer
* - IdentityLayer
* - ReLULayer
* - TanHLayer
* - SoftplusLayer
* - HardSigmoidLayer
* - SwishLayer
* - MishLayer
* - LiSHTLayer
* - GELULayer
* - ELiSHLayer
* - ElliotLayer
* - GaussianLayer
*
* @tparam ActivationFunction Activation function used for the embedding layer.
* @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
* @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
* arma::sp_mat or arma::cube).
*/
template <
class ActivationFunction = LogisticFunction,
typename InputDataType = arma::mat,
typename OutputDataType = arma::mat
>
class BaseLayer
{
public:
/**
* Create the BaseLayer object.
*/
BaseLayer()
{
// Nothing to do here.
}
/**
* Ordinary feed forward pass of a neural network, evaluating the function
* f(x) by propagating the activity forward through f.
*
* @param input Input data used for evaluating the specified function.
* @param output Resulting output activation.
*/
template<typename InputType, typename OutputType>
void Forward(const InputType& input, OutputType& output)
{
ActivationFunction::Fn(input, output);
}
/**
* Ordinary feed backward pass of a neural network, calculating the function
* f(x) by propagating x backward through f, using the results from the feed
* forward pass.
*
* @param input The propagated input activation.
* @param gy The backpropagated error.
* @param g The calculated gradient.
*/
template<typename eT>
void Backward(const arma::Mat<eT>& input,
const arma::Mat<eT>& gy,
arma::Mat<eT>& g)
{
arma::Mat<eT> derivative;
ActivationFunction::Deriv(input, derivative);
g = gy % derivative;
}
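// Note: with the identity activation, Deriv() fills the derivative with
// ones, so g = gy % 1 = gy. The layer therefore passes both activations
// and gradients through unchanged, which is why inserting it between the
// Generator and the Discriminator is safe: it only glues the two
// networks together.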
//! Get the output parameter.
OutputDataType const& OutputParameter() const { return outputParameter; }
//! Modify the output parameter.
OutputDataType& OutputParameter() { return outputParameter; }
//! Get the delta.
OutputDataType const& Delta() const { return delta; }
//! Modify the delta.
OutputDataType& Delta() { return delta; }
/**
* Serialize the layer.
*/
template<typename Archive>