libtorch-char-rnn-shakespeare

libtorch-UNET Carvana vehicle contour segmentation
libtorch-VGG CIFAR-10 classification
libtorch-FCN CamVid semantic segmentation
libtorch-RDN DIV2K super-resolution reconstruction
libtorch-char-rnn-classification libtorch official tutorial series
libtorch-char-rnn-generation libtorch official tutorial series
libtorch-char-rnn-shakespeare libtorch official tutorial series
libtorch-minst libtorch official tutorial series

Preface


Continuing from the official PyTorch RNN tutorials, this article covers text generation in libtorch. It learns from Shakespeare's writings and generates text in their style, and it can easily be modified for image captioning or short-story generation.
This year, China filed more patent applications in content generation than any other country.
Content generation can produce data, text, images, video, music, and more. Still, to repeat the point once more: solving a problem is not about leading the world, but about choosing a suitable model for the actual problem and adapting it.

libtorch C++ libtorch-char-rnn-shakespeare
Dataset preparation, training, and inference are all done in libtorch C++.
Runtime environment:
OS: Windows, 64-bit
OpenCV 3.4.1
libtorch, any 64-bit version
Visual Studio 2017, 64-bit build
References:
https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html
https://github.com/spro/practical-pytorch/tree/master/char-rnn-generation
RNN_KERNEL_VERSION moves to 2.0: the model now uses an encoder-decoder structure, a word-embedding layer, and a GRU-based RNN,
because generating dialogue requires longer-range context.
The series' data.hpp file follows the official tutorial and is renamed helpers.hpp.
The read_file function that loads the raw file contents is adapted from the ReadContentFromFiles function of the previous two versions.

I. What is rnn-shakespeare?

Unlike generation from one-hot character vectors, this version encodes its inputs through an embedding layer (word-embedding style), and structurally it already contains the encoder-decoder core of later seq2seq models. The encoder-decoder demonstrated here is simple, yet what it achieves is remarkably powerful.
Whether it is RDN for image reconstruction or an RNN for content generation, both rely on linking earlier context to later output (skip connections, hidden state) and on an encoder-decoder structure.

II. Usage steps

1. Setting the parameters

int main(int argc, char **argv)
{
	torch::DeviceType device_type;	//select CUDA if available
	if (torch::cuda::is_available())
	{
		printf("torch::cuda::is_available\n");
		device_type = torch::kCUDA;
	}
	else
	{
		printf("cpu is_available\n");
		device_type = torch::kCPU;
	}
	torch::Device device(device_type);

#if _TRAIN_STAGE_		//training stage
	srand((unsigned)time(NULL));

	read_file(L"..\\data\\shakespeare.txt");

	RNN decoder(n_characters, hidden_size, n_characters, n_layers);

	auto decoder_optimizer = torch::optim::Adam(decoder.parameters(), torch::optim::AdamOptions(learning_rate));
	auto criterion = torch::nn::CrossEntropyLoss();

	printf("Training for %d epochs...", n_epochs);

	//float loss_avg = 0;

	for (int epoch = 0; epoch < n_epochs; epoch++)
	{
		std::tuple<torch::Tensor, torch::Tensor> retRandomSet = random_training_set(chunk_len);
		torch::Tensor inp = std::get<0>(retRandomSet);
		torch::Tensor target = std::get<1>(retRandomSet);
		//printf("random_training_set : %d %d\n", inp.sizes(), target.sizes());
		//cout << endl << "inp.sizes()::" << inp.sizes() << "target.sizes()::" << target.sizes() << endl;

		torch::Tensor loss = train(decoder, device, inp, target, decoder_optimizer, criterion);
		//printf("train : %.4f\n", loss.item().toFloat());
		//loss_avg += (loss.item().toFloat() / chunk_len);
		//Print epoch number, loss, name and guess
		if (0 == (epoch % print_every))
		{
			printf("(%d %d%%) loss: %.4f\n", epoch, epoch / n_epochs, (loss.item().toFloat() / chunk_len));
			//loss_avg = 0;
		}
	}

	printf("Finish training!\n");
	torch::serialize::OutputArchive archive;
	decoder.save(archive);
	archive.save_to("..\\retModel\\char-rnn-shakespeare-c++.pt");
	printf("Save the training result to ..\\char-rnn-shakespeare-c++.pt.\n");

#else		//inference stage
	RNN decoder(n_characters, hidden_size, n_characters, n_layers);

	torch::serialize::InputArchive archive;
	archive.load_from("..\\retModel\\char-rnn-shakespeare-c++.pt");

	decoder.load(archive);

	std::string prime_str = "Anger";
	cout << endl << generate(prime_str, decoder, device, 200, 0.8) << endl;

#endif
	return 0;
}

2. Network structure

struct RNN : public torch::nn::Module
{
	//constructed as RNN(n_characters, hidden_size, n_characters, n_layers)
	RNN(int input_size, int hidden_size, int output_size, int n_layers)
	{
		m_input_size = input_size;
		m_hidden_size = hidden_size;
		m_output_size = output_size;
		m_n_layers = n_layers;

		//"encoder" embeds character indices, "gru" carries the context,
		//"decoder" projects the hidden state back to character logits
		encoder = register_module("encoder", torch::nn::Embedding(input_size, hidden_size));
		gru = register_module("gru", torch::nn::GRU(torch::nn::GRUOptions(hidden_size, hidden_size).num_layers(n_layers)));
		decoder = register_module("decoder", torch::nn::Linear(hidden_size, output_size));
	}
	~RNN()
	{

	}

	std::tuple<torch::Tensor, torch::Tensor> forward(torch::Tensor input, torch::Tensor hidden)
	{
		m_input = encoder(input.view({ 1, -1 }));
		std::tuple<torch::Tensor, torch::Tensor> retgru = gru(m_input.view({ 1, 1, -1 }), hidden);
		m_output = std::get<0>(retgru);
		m_hidden = std::get<1>(retgru);
		m_output = decoder(m_output.view({ 1, -1 }));
		return { m_output, m_hidden };
	}

	//Zero hidden state with shape {n_layers, batch = 1, hidden_size}
	torch::Tensor initHidden()
	{
		return torch::zeros({ m_n_layers, 1, m_hidden_size });
	}

	int m_n_layers = 0;
	int m_input_size = 0;
	int m_hidden_size = 0;
	int m_output_size = 0;

	torch::nn::Embedding encoder{ nullptr };
	torch::nn::GRU gru{ nullptr };
	torch::nn::Linear decoder{ nullptr };
	torch::Tensor m_input, m_output, m_hidden;
};
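
To make the tensor shapes concrete, here is a minimal standalone check of one forward step, assuming the RNN struct above is in scope (a sketch; the hyperparameter values are illustrative, not the ones used in training):

#include <torch/torch.h>
#include <iostream>

int main()
{
	const int vocab = 100, hidden = 64, layers = 2;	//illustrative values
	RNN rnn(vocab, hidden, vocab, layers);

	torch::Tensor hidden_state = rnn.initHidden();	//{2, 1, 64}
	torch::Tensor input = torch::zeros({ 1 }, torch::kLong);
	input[0] = 5;	//an arbitrary character index

	std::tuple<torch::Tensor, torch::Tensor> ret = rnn.forward(input, hidden_state);
	std::cout << std::get<0>(ret).sizes() << std::endl;	//[1, 100] logits over the vocabulary
	std::cout << std::get<1>(ret).sizes() << std::endl;	//[2, 1, 64] updated hidden state
	return 0;
}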

3. Training

/*Author: barcaFC											*/
/*Date: 2022-11-10											*/
/*Function: random_training_set								*/
/*Purpose: inp covers chars 0..n-2, target covers 1..n-1	*/
/*Input: chunk_len, how many characters to read				*/
/*Output: inp and target tensors, fed to the embedding		*/

std::tuple<torch::Tensor, torch::Tensor> random_training_set(int chunk_len)
{
	torch::Tensor inp, target;

	int start_index = rand() % gfilelen;
	int end_index = start_index + chunk_len + 1;

	std::string chunk = gfile.substr(start_index, end_index);

	inp = char_tensor(chunk.substr(start_index, chunk_len));
	target = char_tensor(chunk.substr(start_index + 1, chunk_len + 1));

	return { inp ,target };
}
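
The one-character offset between inp and target is what teaches the network to predict the next character. A hypothetical illustration, using char_tensor from section 5 below:

	//Hypothetical example: chunk_len = 4, random chunk happens to be "hello"
	std::string chunk = "hello";
	torch::Tensor inp = char_tensor(chunk.substr(0, 4));	//encodes "hell"
	torch::Tensor target = char_tensor(chunk.substr(1, 4));	//encodes "ello"
	//At step c the network reads inp[c] and is trained to emit target[c]:
	//'h'->'e', 'e'->'l', 'l'->'l', 'l'->'o'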

torch::Tensor train(RNN& decoder, torch::Device device,
	torch::Tensor inp, torch::Tensor target,
	torch::optim::Optimizer& decoder_optimizer, torch::nn::CrossEntropyLoss& criterion)
{
	decoder.to(device);
	decoder.train(true);

	//Keep every tensor on the same device as the model
	torch::Tensor hidden = decoder.initHidden().to(device);
	inp = inp.to(device);
	target = target.to(device);
	decoder_optimizer.zero_grad();

	torch::Tensor loss = torch::zeros({ 1 }, torch::dtype(torch::kFloat64).device(device));
	torch::Tensor output;
	for (int c = 0; c < chunk_len; c++)
	{
		//cout << "inp[c].data.size : " << endl;
		//cout << inp[c].data().sizes() << endl;
		std::tuple<torch::Tensor, torch::Tensor> retTrain = decoder.forward(inp[c], hidden);
		output = std::get<0>(retTrain);
		hidden = std::get<1>(retTrain);

		try
		{
			//CrossEntropyLoss expects a 1-D tensor of class indices
			torch::Tensor temp = torch::zeros({ 1 }, torch::dtype(torch::kLong).device(device));
			temp[0] = target[c].item().toLong();
			loss += criterion(output, temp);
		}
		catch (const c10::Error& e)
		{
			std::cout << "criterion an error has occured : " << e.msg() << std::endl;
			return loss;
		}
	}

	//One backward pass through the whole unrolled sequence
	loss.backward();
	decoder_optimizer.step();

	return loss;
}
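
As a quick sanity check, train() can be called directly on a tiny hand-made pair. This is a sketch that assumes the global chunk_len equals 4 and that n_characters, hidden_size, and n_layers are set as in section 1:

	torch::Tensor inp = char_tensor("hell");
	torch::Tensor target = char_tensor("ello");

	RNN decoder(n_characters, hidden_size, n_characters, n_layers);
	auto decoder_optimizer = torch::optim::Adam(decoder.parameters(), torch::optim::AdamOptions(0.005));
	auto criterion = torch::nn::CrossEntropyLoss();

	torch::Tensor loss = train(decoder, torch::Device(torch::kCPU), inp, target, decoder_optimizer, criterion);
	//For an untrained model the per-character loss should be near
	//log(n_characters), roughly 4.5 for this vocabulary
	std::cout << (loss.item().toFloat() / 4) << std::endl;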

4. Generation

std::string generate(std::string prime_str, RNN& decoder, torch::Device device, int predict_len, float temperature)
{
	decoder.to(device);
	decoder.train(false);

	//Move the hidden state and priming input to the model's device
	torch::Tensor hidden = decoder.initHidden().to(device);
	torch::Tensor prime_input = char_tensor(prime_str).to(device);

	std::string predicted = prime_str;

	//Use priming string to "build up" hidden state
	for (int p = 0; p < prime_str.length(); p++)
	{
		std::tuple<torch::Tensor, torch::Tensor> retHidden = decoder.forward(prime_input[p], hidden);
		hidden = std::get<1>(retHidden);
	}

	torch::Tensor inp = prime_input[prime_str.length() - 1];
	torch::Tensor output;
	for (int c = 0; c < predict_len; c++)
	{
		std::tuple<torch::Tensor, torch::Tensor> retDecoder = decoder.forward(inp, hidden);
		output = std::get<0>(retDecoder);
		hidden = std::get<1>(retDecoder);

		//Sample from the network as a multinomial distribution
		torch::Tensor output_dist = output.data().view(-1).div(temperature).exp();
		torch::Tensor top_i = torch::multinomial(output_dist, 1)[0];

		//Add predicted character to string and use as next input
		char predicted_char = all_characters.at(top_i.item().toInt());
		predicted += predicted_char;
		std::string strINP(1, predicted_char);
		inp = char_tensor(strINP).to(device);
	}

	return predicted;
}
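
The temperature parameter rescales the logits before sampling: values below 1 sharpen the distribution (safer, more repetitive text), while values above 1 flatten it (more surprising but more error-prone text). A small standalone illustration of the same transform generate() applies:

#include <torch/torch.h>
#include <iostream>

int main()
{
	//Illustrative logits for a 3-character vocabulary
	torch::Tensor logits = torch::tensor({ 2.0f, 1.0f, 0.1f });

	for (float t : { 0.5f, 1.0f, 2.0f })
	{
		//Same transform as in generate(); torch::multinomial does not
		//require normalized weights, we normalize here only for display
		torch::Tensor dist = logits.div(t).exp();
		std::cout << "T=" << t << " -> " << (dist / dist.sum()) << std::endl;
	}
	return 0;
}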

5. Helper functions

std::string all_characters = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+, -./:;<=>?@[\\]^_`{|}~";
int n_characters = all_characters.length();

int gfilelen = 0;
std::string gfile;

BOOL read_file(const wchar_t * pwszFilePath)
{
	ifstream fin;
	std::string str;
	fin.open(pwszFilePath, ios::in);

	//Check the stream before using it
	if (!fin.is_open())
	{
		cout << "Cannot open this file!" << endl;
		return FALSE;
	}

	while (getline(fin, str))
	{
		gfile += str;
	}

	//getline strips the line breaks, so measure what was actually kept
	//rather than the raw file size reported by tellg
	gfilelen = (int)gfile.length();

	fin.close();

	return TRUE;
}

torch::Tensor char_tensor(std::string str)
{
	int64_t len = str.length();

	torch::Tensor tensor = torch::zeros({ len }, torch::kLong);
	for (int c = 0; c < len; c++)
	{
		//A character missing from the vocabulary would give npos (-1 as an
		//index) and crash the embedding lookup, so map it to index 0
		size_t idx = all_characters.find(str[c]);
		tensor[c] = (idx == std::string::npos) ? 0 : (int64_t)idx;
	}

	return tensor;
}
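
A quick check of the mapping: 'A' sits at index 36 in all_characters because 10 digits and 26 lowercase letters precede it:

	torch::Tensor t = char_tensor("Ab0");
	std::cout << t << std::endl;	//expected indices: 36, 11, 0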

Summary

Content generation (AIGC) is currently one of the hottest branches of AI. It does more than reproduce its training content: it is also a way of iteratively approximating an analytic solution (the optimum in an effectively unbounded solution space). Combined with extraction-style models, content generation can help complete a great deal of work that used to demand enormous human effort.
