libtorch C++复现GAN网络

目录

1. 原论文

2. libtorch复现

2.1 生成器

2.2 判别器

2.3 完整训练代码

2.4 训练效果


1. 原论文

论文:https://arxiv.org/pdf/1406.2661.pdf

pytorch源码:https://github.com/devnag/pytorch-generative-adversarial-networks/blob/master/gan_pytorch.py

 详细介绍可参考:《pytorch实现GAN》(Mr.Q的CSDN博客)。

2. libtorch复现

2.1 生成器

生成器是全连接网络,输入是服从标准正态分布的随机噪声,size为(b,100),经过5个全连接层后得到(b,784),再view成(b,1,28,28)。由于最后一层接了tanh(双曲正切)激活函数,所以输出值的范围在(-1,1)之间。

// Fully-connected building block: Linear -> (optional) BatchNorm1d -> LeakyReLU(0.2).
class LnBnReluImpl : public torch::nn::Module {
public: 
	// in_c / out_c: input / output feature counts of the Linear layer.
	// normalize: when true, a BatchNorm1d (eps = 0.8) sits between Linear and LeakyReLU.
	LnBnReluImpl(int64_t in_c, int64_t out_c, bool normalize);
	// x: (b, in_c) -> (b, out_c)
	torch::Tensor forward(torch::Tensor x);
private:
	torch::nn::Linear ln = nullptr;
	torch::nn::BatchNorm1d bn = nullptr;   // remains nullptr when normalize == false
	torch::nn::LeakyReLU LReLU = nullptr;  // parameter-free; not registered as a submodule
	bool normalize = true;
};
TORCH_MODULE(LnBnRelu);

LnBnReluImpl::LnBnReluImpl(int64_t in_c, int64_t out_c, bool normalize)
	: ln(register_module("block ln", torch::nn::Linear(in_c, out_c))),
	  LReLU(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true)),  // in-place activation
	  normalize(normalize)
{
	if (!normalize)
		return;
	// epsilon is deliberately 0.8 here (not the usual 1e-5).
	bn = register_module("block bn", torch::nn::BatchNorm1d(torch::nn::BatchNorm1dOptions(out_c).eps(0.8)));
}

torch::Tensor LnBnReluImpl::forward(torch::Tensor x)
{
	torch::Tensor out = ln->forward(x);
	if (normalize)
		out = bn->forward(out);
	return LReLU->forward(out);
}

// Generator MLP: noise (b, NOISE_SIZE) -> 4 LnBnRelu blocks (128, 256, 512, 1024 units;
// the first block has no BatchNorm) -> Linear(C*H*W) -> tanh -> image (b, C, H, W),
// where (C, H, W) = IMAGE_SHAPE.
class GeneratorImpl : public torch::nn::Module {
public:
	GeneratorImpl();
	torch::Tensor forward(torch::Tensor x);
private:
	LnBnRelu fc1 = nullptr;
	LnBnRelu fc2 = nullptr;
	LnBnRelu fc3 = nullptr;
	LnBnRelu fc4 = nullptr;
	torch::nn::Linear fc5{ nullptr };
};
TORCH_MODULE(Generator);

// Layers are built and registered in declaration order (fc1 .. fc5).
GeneratorImpl::GeneratorImpl()
	: fc1(register_module("generator fc1", LnBnRelu(NOISE_SIZE, 128, false))),
	  fc2(register_module("generator fc2", LnBnRelu(128, 256, true))),
	  fc3(register_module("generator fc3", LnBnRelu(256, 512, true))),
	  fc4(register_module("generator fc4", LnBnRelu(512, 1024, true))),
	  fc5(register_module("generator fc5", torch::nn::Linear(1024, int(IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2]))))
{
}

torch::Tensor GeneratorImpl::forward(torch::Tensor x)  // (b, NOISE_SIZE)
{
	torch::Tensor h = fc4(fc3(fc2(fc1(x))));  // (b, 1024)
	h = torch::tanh(fc5(h));                  // (b, C*H*W), squashed into (-1, 1)
	const int64_t batch = h.size(0);
	return h.view({ batch, IMAGE_SHAPE[0], IMAGE_SHAPE[1], IMAGE_SHAPE[2] });  // (b, C, H, W)
}

2.2 判别器

判别器也是全连接网络,输入数据的大小为(b,1,28,28),先view成(b,784)的向量,再经过3个全连接层得到(b,1)的输出,最后经过sigmoid将分数映射到(0,1)之间。值越接近1,说明判别器认为输入越真实。

// Discriminator MLP: image (b, C, H, W) -> flatten -> Linear(512) -> LeakyReLU(0.2)
// -> Linear(256) -> LeakyReLU(0.2) -> Linear(1) -> sigmoid, giving one score in (0, 1) per sample.
class DiscriminatorImpl : public torch::nn::Module {
public:
	DiscriminatorImpl();
	torch::Tensor forward(torch::Tensor x);
private:
	torch::nn::Linear fc1{ nullptr };
	torch::nn::Linear fc2{ nullptr };
	torch::nn::Linear fc3{ nullptr };
	torch::nn::LeakyReLU relu1{ nullptr };  // parameter-free; not registered
	torch::nn::LeakyReLU relu2{ nullptr };  // parameter-free; not registered
};
TORCH_MODULE(Discriminator);

// The Linear layers are created and registered in the order fc1, fc2, fc3.
DiscriminatorImpl::DiscriminatorImpl()
	: fc1(register_module("disciminator fc1", torch::nn::Linear(IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2], 512))),
	  fc2(register_module("disciminator fc2", torch::nn::Linear(512, 256))),
	  fc3(register_module("disciminator fc3", torch::nn::Linear(256, 1))),
	  relu1(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true)),
	  relu2(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true))
{
}

torch::Tensor DiscriminatorImpl::forward(torch::Tensor x)  // (b, C, H, W)
{
	const torch::Tensor flat = x.view({ x.size(0), -1 });  // (b, C*H*W)
	torch::Tensor h = relu1(fc1(flat));                     // (b, 512)
	h = relu2(fc2(h));                                      // (b, 256)
	return torch::sigmoid(fc3(h));                          // (b, 1), in (0, 1)
}

2.3 完整训练代码

这里基于MNIST数据集生成手写数字。后续将探索生成人脸数据。

#include <torch/torch.h>
#include <torch/script.h>

#include <algorithm>
#include <iostream>

#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/types_c.h>

// Training configuration.
const std::string DATA_FOLDER = "./data/MNIST/raw";  // directory holding the raw MNIST files
const int64_t BATCH_SIZE = 64;
const int64_t N_EPOCHS = 200;
const int64_t NOISE_SIZE = 100;                      // length of the generator's input noise vector
torch::Device device(torch::kCPU);                   // main() switches this to CUDA when available
const std::vector<int> IMAGE_SHAPE{ 1, 28, 28 };     // (channels, height, width); read-only

// Fully-connected building block: Linear -> (optional) BatchNorm1d -> LeakyReLU(0.2).
class LnBnReluImpl : public torch::nn::Module {
public: 
	// in_c / out_c: input / output feature counts of the Linear layer.
	// normalize: when true, a BatchNorm1d (eps = 0.8) sits between Linear and LeakyReLU.
	LnBnReluImpl(int64_t in_c, int64_t out_c, bool normalize);
	// x: (b, in_c) -> (b, out_c)
	torch::Tensor forward(torch::Tensor x);
private:
	torch::nn::Linear ln = nullptr;
	torch::nn::BatchNorm1d bn = nullptr;   // remains nullptr when normalize == false
	torch::nn::LeakyReLU LReLU = nullptr;  // parameter-free; not registered as a submodule
	bool normalize = true;
};
TORCH_MODULE(LnBnRelu);

LnBnReluImpl::LnBnReluImpl(int64_t in_c, int64_t out_c, bool normalize)
	: ln(register_module("block ln", torch::nn::Linear(in_c, out_c))),
	  LReLU(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true)),  // in-place activation
	  normalize(normalize)
{
	if (!normalize)
		return;
	// epsilon is deliberately 0.8 here (not the usual 1e-5).
	bn = register_module("block bn", torch::nn::BatchNorm1d(torch::nn::BatchNorm1dOptions(out_c).eps(0.8)));
}

torch::Tensor LnBnReluImpl::forward(torch::Tensor x)
{
	torch::Tensor out = ln->forward(x);
	if (normalize)
		out = bn->forward(out);
	return LReLU->forward(out);
}

// Generator MLP: noise (b, NOISE_SIZE) -> 4 LnBnRelu blocks (128, 256, 512, 1024 units;
// the first block has no BatchNorm) -> Linear(C*H*W) -> tanh -> image (b, C, H, W),
// where (C, H, W) = IMAGE_SHAPE.
class GeneratorImpl : public torch::nn::Module {
public:
	GeneratorImpl();
	torch::Tensor forward(torch::Tensor x);
private:
	LnBnRelu fc1 = nullptr;
	LnBnRelu fc2 = nullptr;
	LnBnRelu fc3 = nullptr;
	LnBnRelu fc4 = nullptr;
	torch::nn::Linear fc5{ nullptr };
};
TORCH_MODULE(Generator);

// Layers are built and registered in declaration order (fc1 .. fc5).
GeneratorImpl::GeneratorImpl()
	: fc1(register_module("generator fc1", LnBnRelu(NOISE_SIZE, 128, false))),
	  fc2(register_module("generator fc2", LnBnRelu(128, 256, true))),
	  fc3(register_module("generator fc3", LnBnRelu(256, 512, true))),
	  fc4(register_module("generator fc4", LnBnRelu(512, 1024, true))),
	  fc5(register_module("generator fc5", torch::nn::Linear(1024, int(IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2]))))
{
}

torch::Tensor GeneratorImpl::forward(torch::Tensor x)  // (b, NOISE_SIZE)
{
	torch::Tensor h = fc4(fc3(fc2(fc1(x))));  // (b, 1024)
	h = torch::tanh(fc5(h));                  // (b, C*H*W), squashed into (-1, 1)
	const int64_t batch = h.size(0);
	return h.view({ batch, IMAGE_SHAPE[0], IMAGE_SHAPE[1], IMAGE_SHAPE[2] });  // (b, C, H, W)
}

// Discriminator MLP: image (b, C, H, W) -> flatten -> Linear(512) -> LeakyReLU(0.2)
// -> Linear(256) -> LeakyReLU(0.2) -> Linear(1) -> sigmoid, giving one score in (0, 1) per sample.
class DiscriminatorImpl : public torch::nn::Module {
public:
	DiscriminatorImpl();
	torch::Tensor forward(torch::Tensor x);
private:
	torch::nn::Linear fc1{ nullptr };
	torch::nn::Linear fc2{ nullptr };
	torch::nn::Linear fc3{ nullptr };
	torch::nn::LeakyReLU relu1{ nullptr };  // parameter-free; not registered
	torch::nn::LeakyReLU relu2{ nullptr };  // parameter-free; not registered
};
TORCH_MODULE(Discriminator);

// The Linear layers are created and registered in the order fc1, fc2, fc3.
DiscriminatorImpl::DiscriminatorImpl()
	: fc1(register_module("disciminator fc1", torch::nn::Linear(IMAGE_SHAPE[0] * IMAGE_SHAPE[1] * IMAGE_SHAPE[2], 512))),
	  fc2(register_module("disciminator fc2", torch::nn::Linear(512, 256))),
	  fc3(register_module("disciminator fc3", torch::nn::Linear(256, 1))),
	  relu1(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true)),
	  relu2(torch::nn::LeakyReLUOptions().negative_slope(0.2).inplace(true))
{
}

torch::Tensor DiscriminatorImpl::forward(torch::Tensor x)  // (b, C, H, W)
{
	const torch::Tensor flat = x.view({ x.size(0), -1 });  // (b, C*H*W)
	torch::Tensor h = relu1(fc1(flat));                     // (b, 512)
	h = relu2(fc2(h));                                      // (b, 256)
	return torch::sigmoid(fc3(h));                          // (b, 1), in (0, 1)
}


// Show up to the first 10 images of a generated batch side by side in the "visualize" window.
// samples: (b, C, H, W) generator output with values in (-1, 1) (tanh range); only channel 0
// of each image is displayed. Does nothing for an empty batch.
void Visualize(const torch::Tensor& samples)
{
	const int64_t n = std::min<int64_t>(10, samples.size(0));  // never index past the batch
	if (n == 0)
		return;
	const int rows = static_cast<int>(samples.size(2));  // H
	const int cols = static_cast<int>(samples.size(3));  // W
	// cv::Size is (width, height): the strip is n images wide and one image tall.
	cv::Mat scene(cv::Size(cols * static_cast<int>(n), rows), CV_32F);

	for (int64_t i = 0; i < n; i++)
	{
		// imshow multiplies CV_32F pixels by 255, so map (-1, 1) to (0, 1) first;
		// otherwise every negative value is clipped to black.
		torch::Tensor image_tensor = ((samples[i][0].detach() + 1.0) / 2.0)
			.to(torch::kFloat)
			.cpu()
			.contiguous();  // (H, W), dense float32 so data_ptr() can back a cv::Mat
		// Wraps the tensor's memory (no copy); image_tensor outlives the copyTo below.
		cv::Mat image_mat(rows, cols, CV_32F, image_tensor.data_ptr());
		image_mat.copyTo(scene(cv::Rect(cols * static_cast<int>(i), 0, cols, rows)));
	}
	cv::namedWindow("visualize", cv::WINDOW_NORMAL);
	cv::imshow("visualize", scene);
	cv::waitKey(1);
}

// Train the GAN on MNIST for N_EPOCHS epochs. Both losses are logged every batch, and the
// current generator output is shown after every batch.
int main()
{
	if (torch::cuda::is_available())
		device = torch::Device(torch::kCUDA);

	// Assumes the raw MNIST files (http://yann.lecun.com/exdb/mnist) are available under DATA_FOLDER.
	// Normalize(mean = 0.5, stddev = 0.5) rescales pixels from [0, 1] to [-1, 1], the same range
	// as the generator's tanh output.
	auto dataset = torch::data::datasets::MNIST(DATA_FOLDER)
		.map(torch::data::transforms::Normalize<>(0.5, 0.5))
		.map(torch::data::transforms::Stack<>());
	// drop_last(true) below discards the trailing partial batch, so the count rounds down.
	const int64_t batches_per_epoch = static_cast<int64_t>(dataset.size().value()) / BATCH_SIZE;

	auto options = torch::data::DataLoaderOptions();
	options.drop_last(true);
	options.batch_size(BATCH_SIZE);
	options.workers(2);
	auto data_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(std::move(dataset), options);

	// Initialize generator and discriminator
	Generator generator = Generator();
	Discriminator discriminator = Discriminator();
	generator->to(device);
	discriminator->to(device);

	// Binary cross-entropy on the discriminator's sigmoid output.
	torch::nn::BCELoss adversarial_loss = torch::nn::BCELoss();
	adversarial_loss->to(device);

	// Adam optimizers with betas (0.5, 0.999).
	// NOTE(review): lr = 0.002 is 10x the 2e-4 commonly used for this GAN setup -- confirm intended.
	torch::optim::Adam optimizer_G(generator->parameters(), torch::optim::AdamOptions(0.002).betas(std::make_tuple(0.5, 0.999)));
	torch::optim::Adam optimizer_D(discriminator->parameters(), torch::optim::AdamOptions(0.002).betas(std::make_tuple(0.5, 0.999)));

	// Label tensors are allocated directly on the training device (no CPU allocation + copy).
	const auto label_options = torch::TensorOptions().dtype(torch::kFloat).device(device);

	for (int64_t epoch = 1; epoch <= N_EPOCHS; epoch++)
	{
		int64_t batch_index = 0;
		for (torch::data::Example<>& batch : *data_loader)
		{
			++batch_index;
			const int64_t batch_size = batch.data.size(0);

			// Adversarial ground truths
			torch::Tensor valid = torch::ones({ batch_size, 1 }, label_options);   // (b,1)
			torch::Tensor fake = torch::zeros({ batch_size, 1 }, label_options);   // (b,1)

			torch::Tensor real_imges = batch.data.to(device);  // (b,1,28,28)

			// ---- Train generator: push D to classify generated images as real. ----
			optimizer_G.zero_grad();
			torch::Tensor z = torch::randn({ batch_size, NOISE_SIZE }, device);  // (b,100)
			torch::Tensor gen_imges = generator(z);  // (b,1,28,28)
			torch::Tensor g_loss = adversarial_loss(discriminator(gen_imges), valid);
			g_loss.backward();
			optimizer_G.step();

			// ---- Train discriminator: real -> 1, generated -> 0. ----
			optimizer_D.zero_grad();  // also clears D grads accumulated by g_loss.backward()
			torch::Tensor real_loss = adversarial_loss(discriminator(real_imges), valid);
			// detach(): the generator's graph was already freed by g_loss.backward(), and this
			// update must only reach the discriminator, not the generator.
			torch::Tensor fake_loss = adversarial_loss(discriminator(gen_imges.detach()), fake);
			torch::Tensor d_loss = (real_loss + fake_loss) / 2;
			d_loss.backward();
			optimizer_D.step();

			std::cout << "Epoch: " << epoch << "/" << N_EPOCHS
				<< ", Batch: " << batch_index << "/" << batches_per_epoch
				<< " - D loss: " << d_loss.item().toDouble() << ", G loss: " << g_loss.item().toDouble() << std::endl;

			Visualize(gen_imges);
		}
	}
	return 0;
}

2.4 训练效果

epoch 10:

 epoch 80:

 

 

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 10
    评论
libtorch是一个用于C++的开源机器学习库,它是PyTorch框架的C++前端,可以在不依赖Python环境的情况下使用PyTorch的功能。它提供了一些用于构建、训练和部署深度学习模型的工具和接口。 MobileNet是一种轻量级的卷积神经网络模型,专门设计用于在移动设备和嵌入式系统上进行实时图像分类和目标检测。MobileNet_v2是对MobileNet的改进版本,通过引入更多的深度可分离卷积和倒残差结构,提高了模型的性能和效率。 使用libtorch和MobileNet_v2,我们可以在C++环境中构建、训练和部署目标检测或图像分类模型。首先,我们可以使用libtorch提供的工具将MobileNet_v2的模型定义加载到C++程序中。然后,我们可以使用该模型进行推理,对输入图像进行分类或目标检测,并获取相应的输出结果。 在使用libtorch和MobileNet_v2时,需要注意以下几点:首先,我们需要确保在环境中正确配置了libtorch库,并将其链接到我们的C++程序中。其次,我们可以根据具体的任务需求,使用MobileNet_v2的预训练模型或根据自己的数据集进行训练和微调。最后,我们可以使用libtorch提供的接口进行模型的推理和结果的处理。 总而言之,libtorch和MobileNet_v2的结合可以提供一个在C++环境中进行目标检测和图像分类的解决方案,使得我们可以在移动设备或嵌入式系统中部署高性能且轻量级的深度学习模型。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Mr.Q

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值