libtorch-char-rnn-classification

libtorch-char-rnn-classification

libtorch-UNET Carvana车辆轮廓分割
libtorch-VGG cifar10分类
libtorch-FCN CamVid语义分割
libtorch-RDN DIV2K超分重建
libtorch-char-rnn-classification libtorch官网系列教程
libtorch-char-rnn-generation libtorch官网系列教程
libtorch-char-rnn-shakespeare libtorch官网系列教程
libtorch-minst libtorch官网系列教程
————————————————

                        版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。

原文链接:https://blog.csdn.net/barcaFC/article/details/140032098



前言

libtorch C++ pytorch rnn char-classification
从数据集准备、训练、推理 全部由libtorch c++完成
运行环境:
操作系统:windows 64位
opencv3.4.1
libtorch 任意版本 64位
visual studio 2017 64位编译
数据集:names(按国家/语言分类的人名数据集,位于 ../data/names)
参考教程:https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html


一、rnn是什么?

循环神经网络,一种增加了隐藏状态值并且能够一对多生成或者多对一生成结果的网络。大量运用在NLP领域中,然而这种类型的网络并非只适用于NLP,如果把NLP中的语法、上下文、语义、标点符号等相关联的约束替换成另外的表达,例如切割行业中的横竖切割、所求面积尽量最大、排列的合理约束,RNN同样能够学习这些特性而生成相当具有可执行力的结果。一种基于循环神经网络的智能排板方法
对比遗传算法:
首先遗传算法在对样本遗传和变异两个阶段过于随机,并且遗传算法仅通过交叉结合和随机突变进行样本修改,生成新的样本,然后再送入评估函数进行评分,在本质上新样本的生成和分值评估两个阶段发生了断裂,二者之间并没有紧密关联,从而使得遗传算法在新样本生成方向上以及遗传变异方向更趋于随机猜测,最终迭代出优秀样本的质量和轮次无法控制和调整。
其次,遗传算法是一种传统算法,并没有记忆性,而是通过设计一种算法在需要的时候,每次进行重新计算,因此使用遗传算法在工业场景下需要难以接受的计算时间以及计算资源。
因此,如何结合现代最新的深度学习技术解决之前传统智能算法的致命不足,从而寻找到一种有效的方法解决问题,对整个行业具有着革新的影响。
冰冻三尺非一日之寒,从基础开始,pytorch上的教程很详细,很经典,希望读者能够扎实地一句一句去实现和理解。

二、使用步骤

1.设置参数

// Entry point.
// Training stage (_TRAIN_STAGE_ defined non-zero): builds the category list
// from the data directory, trains the RNN with SGD + NLL loss on random
// (category, name) pairs, then serializes the model.
// Inference stage: reloads the serialized model and prints the top-3
// category guesses for a sample name.
int main()
{
	if (torch::cuda::is_available())
	{
		printf("torch::cuda::is_available\n");
		device_type = torch::kCUDA;
	}
	else
	{
		printf("cpu is_available\n");
		device_type = torch::kCPU;
	}
	torch::Device device(device_type);

#if _TRAIN_STAGE_		// training

	srand((unsigned)time(NULL));

	// Data preprocessing:
	// collect every category under the names directory into gAllCategorysList
	getAllCategorys(DIRECTORYPATH, gAllCategorysList);
	// map each category to all of its lines (names) in gMapCategoryAllLines
	getAllLinesForCategorys(DIRECTORYPATH);

	// total number of categories (= output size of the network)
	gn_categories = GetListElemCount(gAllCategorysList);

	// input = alphabet size, hidden size is user-chosen, output = category count
	RNN rnn(gn_letters, n_hidden, gn_categories);

	auto optimizer = torch::optim::SGD(rnn.parameters(), torch::optim::SGDOptions(learning_rate));
	auto criterion = torch::nn::NLLLoss();

	for (int epoch = 0; epoch < n_epochs; epoch ++)
	{
		std::string category, line;

		// draw a random (category, name) pair and its tensor encodings
		std::tuple<torch::Tensor, torch::Tensor> ret = randomTrainingPair(category, line);
		torch::Tensor category_tensor = std::get<0>(ret);
		torch::Tensor line_tensor = std::get<1>(ret);
		std::tuple<torch::Tensor, torch::Tensor> retTran = train(rnn, device, category_tensor, line_tensor, optimizer, criterion);
		torch::Tensor output = std::get<0>(retTran);
		torch::Tensor loss = std::get<1>(retTran);

		// Every print_every epochs: report the loss, and the guess when correct
		if (0 == (epoch % print_every))
		{
			std::string guss = categoryFromOutput(output);
			if (0 == guss.compare(category))
				printf("%s %s right\n", guss.c_str(), category.c_str());
			printf("loss: %.3f\n", loss.item().toFloat());
		}
	}

	printf("Finish training!\n");
	torch::serialize::OutputArchive archive;
	rnn.save(archive);
	archive.save_to("..\\retModel\\char-rnn-classification-c++.pt");
	// bug fix: message previously reported "..\char-rnn-classification-c++.pt",
	// which is not the path the archive was actually saved to
	printf("Save the training result to ..\\retModel\\char-rnn-classification-c++.pt.\n");

#else		// inference

	srand((unsigned)time(NULL));

	// data preprocessing: only the category list is needed for inference
	getAllCategorys(DIRECTORYPATH, gAllCategorysList);
	gn_categories = GetListElemCount(gAllCategorysList);

	RNN		rnn(gn_letters, n_hidden, gn_categories);
	torch::serialize::InputArchive archive;
	archive.load_from("..\\retModel\\char-rnn-classification-c++.pt");

	rnn.load(archive);

	std::string line = "Hazaki";
	predict(line, 3, rnn, device);
#endif // _TRAIN_STAGE_
	return 0;
}

2.数据处理(One-Hot)

//数据样本目录:../data/names
/*1:首先读出所有的名字names所属类别category(国家)		*/
/*2:然后归类出所有类别下的所有名字						*/
/*3:设计一个tensor形状为{[名字长度],[0],[所有字符数]}	*/
/*{[lens行], [0列], [对应字符one-hot]}					*/

//结构体定义
/*typedef struct ALL_CATEGORYS;		所有名字所属类别							*/
/*typedef map CATEGORY_ALL_LINES;	所有类别下的所有名字(LINES,一行一个名字)	*/
/*要对所有字符进行unicode-ansi转换												*/
/*tensor size {[lens], [0], [all_letters]}										*/

// Alphabet used for one-hot encoding: the 52 ASCII letters plus " .,;'-".
// NOTE(review): unlike the Python tutorial's string constant, this is a plain
// char array with NO terminating '\0', so sizeof(gszAll_Letters) is exactly
// the letter count (58) — do not replace it with a string literal.
char gszAll_Letters[] = { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
						'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
						'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
						'w', 'x', 'y', 'z', ' ', '.', ',', ';', '\'', '-' };

// Number of distinct letters — the width of each one-hot input vector.
int gn_letters = sizeof(gszAll_Letters);

// Maps a character to its position in gszAll_Letters (the one-hot index).
// Returns -1 when the character is not part of the alphabet.
int letterToIndex(char szLetter)
{
	const int nCount = (int)sizeof(gszAll_Letters);

	// Linear scan is fine here: the alphabet has only 58 entries.
	for (int nIdx = 0; nIdx < nCount; ++nIdx)
	{
		if (szLetter == gszAll_Letters[nIdx])
			return nIdx;
	}

	return -1;	// unknown character
}

// Encodes a name as a one-hot tensor of shape {line_length, 1, n_letters}:
// row i has a single 1 at the alphabet index of the i-th character.
torch::Tensor lineToTensor(const char* szLine)
{
	LONGLONG lLineLen = strlen(szLine);
	torch::Tensor tensor = torch::zeros({ lLineLen, 1, sizeof(gszAll_Letters) });

	for (int i = 0; i < lLineLen; i++)
	{
		int nIndex = letterToIndex(szLine[i]);
		// bug fix: letterToIndex() returns -1 for characters outside the
		// alphabet; the unchecked tensor[i][0][-1] would silently wrap to the
		// LAST letter ('-') under libtorch negative indexing. Unknown
		// characters now get an all-zero row instead.
		if (nIndex >= 0)
			tensor[i][0][nIndex] = 1;
	}

#if DEBUG_PRINT
	cout << "one-hot line tensor size:\n"<< tensor.data().sizes() << endl;
	cout << "one-hot line tensor:\n" << tensor.data() << endl;
#endif

	return tensor;
}

3.RNN网络

// Vanilla character-level RNN cell (after the PyTorch char-rnn tutorial):
// each step concatenates the one-hot input with the previous hidden state,
// feeds the result through two linear layers (next hidden state / output),
// and applies LogSoftmax so the output is log-probabilities over categories.
struct RNN : public torch::nn::Module
{
	// input_size:  width of the one-hot letter vector (gn_letters)
	// hidden_size: size of the recurrent state
	// output_size: number of categories
	RNN(int input_size, int hidden_size, int output_size)
	{
		m_hidden_size = hidden_size;

		// register_module() names must stay stable — they are the keys used
		// when the model is saved to / loaded from an archive
		i2h = register_module("i2h", torch::nn::Linear(input_size + hidden_size, hidden_size));
		i2o = register_module("i2o", torch::nn::Linear(input_size + hidden_size, output_size));
		softmax = register_module("softmax", torch::nn::LogSoftmax(torch::nn::LogSoftmaxOptions(1)));
	}

	~RNN()
	{

	}

	// One time step. Returns { output (log-probs), new hidden state }.
	std::tuple<torch::Tensor, torch::Tensor> forward(torch::Tensor input, torch::Tensor hidden)
	{
		// tensors must share the same dtype (and device) to be concatenated
		combined = torch::cat({ input, hidden }, 1);
		hidden = i2h(combined);
		output = i2o(combined);
		output = softmax(output);
		//output = torch::log_softmax(output, 1);	// fallback: this line also works
		return { output, hidden };
	}

	// Fresh zero hidden state, shape {1, hidden_size} (allocated on CPU).
	torch::Tensor initHidden()
	{
		return torch::zeros({1, m_hidden_size });
	}

	int m_hidden_size = 0;

	torch::nn::Linear i2h{ nullptr }, i2o{ nullptr };
	torch::nn::LogSoftmax softmax = nullptr;
	// NOTE(review): keeping forward() temporaries as members retains the last
	// step's autograd graph between calls — consider making them locals.
	torch::Tensor combined, output;
};

#endif // _MODEL_H

4.训练

// One training step on a single (category, name) pair: unrolls the RNN over
// the name's characters, computes the NLL loss of the final output against
// the target category, and applies one SGD update.
// Returns { final output (log-probs), loss }.
std::tuple<torch::Tensor, torch::Tensor>  train(RNN& rnn, torch::Device device,
												torch::Tensor category_tensor, torch::Tensor line_tensor,
												torch::optim::Optimizer& optimizer, torch::nn::NLLLoss& criterion)
{
	rnn.to(device);
	rnn.train(true);

	// bug fix: initHidden() and the sample tensors are created on CPU; when
	// the model was moved to CUDA the torch::cat in forward() would fail on
	// mismatched devices. Keep everything on the model's device.
	torch::Tensor hidden = rnn.initHidden().to(device);
	line_tensor = line_tensor.to(device);
	category_tensor = category_tensor.to(device);

	optimizer.zero_grad();

	torch::Tensor output;

	// Unroll the RNN over each character (one one-hot row at a time),
	// carrying the hidden state forward.
	int nline_tensor_size = (int)line_tensor.sizes()[0];
	for (int i = 0; i < nline_tensor_size; i++)
	{
		std::tuple<torch::Tensor, torch::Tensor> ret = rnn.forward(line_tensor[i], hidden);
		output = std::get<0>(ret);
		hidden = std::get<1>(ret);
	}

	// fix: these traces printed on EVERY step; guard them under DEBUG_PRINT
	// for consistency with lineToTensor()
#if DEBUG_PRINT
	cout << "out sizes : " << output.sizes() << endl;
	cout << "category_tensor sizes : " << category_tensor.sizes() << endl;
#endif

	auto loss = criterion(output, category_tensor);
	loss.backward();

	optimizer.step();

	return { output, loss };
}

5.推理

// Runs the RNN over a whole encoded name (inference only) and returns the
// final output: log-probabilities over the categories.
torch::Tensor evaluate(RNN& rnn, torch::Device device, torch::Tensor line_tensor)
{
	// Inference: no autograd graph needed — saves memory and time.
	torch::NoGradGuard no_grad;

	rnn.to(device);
	rnn.train(false);

	// bug fix: initHidden() and the input are CPU tensors; move them to the
	// model's device so torch::cat in forward() does not fail under CUDA.
	torch::Tensor hidden = rnn.initHidden().to(device);
	line_tensor = line_tensor.to(device);

	torch::Tensor output;

	// Feed the characters one at a time, threading the hidden state through.
	int nline_tensor_size = (int)line_tensor.sizes()[0];
	for (int i = 0; i < nline_tensor_size; i++)
	{
		std::tuple<torch::Tensor, torch::Tensor> ret = rnn.forward(line_tensor[i], hidden);
		output = std::get<0>(ret);
		hidden = std::get<1>(ret);
	}

	return output;
}

void predict(std::string line, int n_predictions, RNN& rnn, torch::Device device)
{
	torch::Tensor output = evaluate(rnn, device, lineToTensor(line.c_str()));

	std::tuple<torch::Tensor, torch::Tensor> topRet = output.topk(n_predictions, 1, true);

	torch::Tensor topv = std::get<0>(topRet);
	torch::Tensor topi = std::get<1>(topRet);

	for (int i = 0; i < n_predictions; i++)
	{
		float value = topv[0][i].item<float>();
		int category_index = topi[0][i].item<int>();
		printf("(%.2f) %s\n", value, FindListElem(gAllCategorysList, category_index).c_str());
	}
}

总结

区别于像素点RGB三色推理,RNN提供了基于ONE-HOT或者其他相似样本编码类型的迭代解析解,它能够推断某种排列组合,而且相比于传统的遗传、蚁群、退火等算法,更加的先进合理,有迹可寻。

  • 31
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值