libtorch-char-rnn-generation

Posts in this series:

libtorch-UNET Carvana vehicle contour segmentation
libtorch-VGG CIFAR-10 classification
libtorch-FCN CamVid semantic segmentation
libtorch-RDN DIV2K super-resolution reconstruction
libtorch-char-rnn-classification (libtorch official tutorial series)
libtorch-char-rnn-generation (libtorch official tutorial series)
libtorch-char-rnn-shakespeare (libtorch official tutorial series)
libtorch-minst (libtorch official tutorial series)
Copyright notice: this is an original post by the author, released under the CC 4.0 BY-SA license. Please include the original link and this notice when reposting.

Original link: https://blog.csdn.net/barcaFC/article/details/140090906



Preface

Continuing the official PyTorch RNN tutorials, this post covers character-level generation. Later architectures such as the Transformer, and large models such as ChatGPT, were all built on and evolved from this smallest RNN model, and the gradient descent at its core runs through every modern deep-learning framework.

libtorch C++: libtorch-char-rnn-generation
Dataset preparation, training and inference are all implemented in libtorch C++.
Runtime environment:
OS: Windows, 64-bit
OpenCV 3.4.1
libtorch, any 64-bit version
Visual Studio 2017, 64-bit build
Dataset: names
Reference tutorial: https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html


I. What is rnn-char-generation?

Character-level generation predicts the next character from the previous ones plus a conditioning feature. In this post the model is conditioned on a nationality (the category) and a starting letter, and generates a plausible name character by character; the companion classification tutorial goes the other way and infers the nationality from a name. The same next-character idea is the foundation of generating articles and dialogue. For example, for the Chinese name "Wang" the training steps are W→a, a→n, n→g, g→EOS.

II. Usage steps

1. Parameter setup

The code is as follows (example):

int main(int argc, char **argv)
{
	if (torch::cuda::is_available())
	{
		printf("torch::cuda::is_available\n");
		device_type = torch::kCUDA;
	}
	else
	{
		printf("cpu is_available\n");
		device_type = torch::kCPU;
	}
	torch::Device device(device_type);

#if _TRAIN_STAGE_		// training path

	//Preparing for Training
	srand((unsigned)time(NULL));

	// data preprocessing
	read_categorys(DIRECTORYPATH, gAllCategorysList);
	read_lines(DIRECTORYPATH);

	gn_categories = GetListElemCount(gAllCategorysList);

	RNN rnn(gn_categories, gn_letters, hidden_size, gn_letters);

	auto optimizer = torch::optim::Adam(rnn.parameters(), torch::optim::AdamOptions(learning_rate));
	auto criterion = torch::nn::CrossEntropyLoss();

	for (int epoch = 0; epoch < n_epochs; epoch++)
	{
		std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> ret = randomTrainingPair();
		torch::Tensor category_tensor = std::get<0>(ret);
		torch::Tensor line_tensor = std::get<1>(ret);
		torch::Tensor target_tensor = std::get<2>(ret);
		int nline_tensor_size = 0;
		nline_tensor_size = line_tensor.sizes()[0];
		std::tuple<torch::Tensor, torch::Tensor> retTran = train(rnn, device, category_tensor, line_tensor, target_tensor, optimizer, criterion);
		torch::Tensor output = std::get<0>(retTran);
		torch::Tensor loss = std::get<1>(retTran);
		loss_avg += (loss.item().toFloat() / nline_tensor_size);
		// Print iteration number, progress (%) and the average per-character loss since the last print
		if (0 == (epoch % print_every))
		{
			printf("(%d %d%%) loss: %.4f\n", epoch, epoch * 100 / n_epochs, loss_avg / print_every);
			loss_avg = 0;
		}
	}

	printf("Finish training!\n");
	torch::serialize::OutputArchive archive;
	rnn.save(archive);
	archive.save_to("..\\retModel\\char-rnn-generation-c++.pt");
	printf("Save the training result to ..\\char-rnn-generation-c++.pt.\n");

#else		// inference path

	std::string strCategory = "Chinese";
	std::string strStart_Chars = "CHI";

	// data preprocessing
	read_categorys(DIRECTORYPATH, gAllCategorysList);
	read_lines(DIRECTORYPATH);

	gn_categories = GetListElemCount(gAllCategorysList);

	RNN rnn(gn_categories, gn_letters, hidden_size, gn_letters);

	torch::serialize::InputArchive archive;
	archive.load_from("..\\retModel\\char-rnn-generation-c++.pt");

	rnn.load(archive);

	generate(strCategory, strStart_Chars, rnn, device);

#endif

	return 0;
}
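
The main() above references several globals (device_type, hidden_size, learning_rate, n_epochs, print_every, loss_avg, gmax_length, and the _TRAIN_STAGE_ switch) whose definitions are not shown in this post. A minimal sketch of plausible definitions follows; the values are borrowed from the official PyTorch tutorial defaults and are illustrative, not necessarily what the author used:

#define _TRAIN_STAGE_	1			// 1 = build the training path, 0 = build the inference path

torch::DeviceType device_type = torch::kCPU;

int   hidden_size   = 128;			// size of the RNN hidden state
float learning_rate = 0.0005f;		// Adam step size
int   n_epochs      = 100000;		// number of random training samples
int   print_every   = 5000;			// print progress every print_every samples
float loss_avg      = 0.0f;			// accumulated per-character loss between prints
int   gmax_length   = 20;			// maximum length of a generated name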

2. Data processing

The code is as follows (example):

//Preparing the Data
char gszAll_Letters[] = { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
						'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
						'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
						'w', 'x', 'y', 'z', ' ', '.', ',', ';', '\'', '-' };

int gn_letters = sizeof(gszAll_Letters) + 1;	// all letters plus one extra slot for the EOS marker

int EOS = gn_letters - 1;

int letterToIndex(char szLetter)
{
	int seq = 0;
	for (seq = 0; seq < sizeof(gszAll_Letters); seq++)
	{
		if (gszAll_Letters[seq] == szLetter)
			break;
	}

	if (seq == sizeof(gszAll_Letters))
		return -1;

	return seq;
}

/* Author:  barcaFC										*/
/* Date:    2022-11-1									*/
/* Function: lineToTensor								*/
/* Purpose: convert a line into a one-hot tensor		*/
/*          <line_length x 1 x n_letters>				*/
/* Input:   line										*/
/* Output:  the one-hot tensor for the line				*/

torch::Tensor lineToTensor(const char* szLine)
{
	LONGLONG lLineLen = strlen(szLine);
	torch::Tensor tensor = torch::zeros({ lLineLen, 1, gn_letters });

	for (int i = 0; i < lLineLen; i++)
	{
		int idx = letterToIndex(szLine[i]);
		if (idx < 0)		// skip characters that are not in gszAll_Letters
			continue;
		tensor[i][0][idx] = 1;
	}

#if DEBUG_PRINT
	cout << "one-hot line tensor size:\n" << tensor.data().sizes() << endl;
	cout << "one-hot line tensor:\n" << tensor.data() << endl;
#endif

	return tensor;
}

torch::Tensor categoryToTensor(const char* szCategory)
{
	std::string		strCategory = szCategory;
	int li = FindListElemSeq(gAllCategorysList, strCategory);
	torch::Tensor tensor = torch::zeros({ 1, gn_categories });
	tensor[0][li] = 1;

	return tensor;
}

torch::Tensor targetToTensor(const char* szTarget)
{
	LONGLONG lLineLen = strlen(szTarget);

	torch::Tensor tensor = torch::zeros({ lLineLen }, torch::kLong);

	// targets are characters 2..N of the line followed by EOS (next-character prediction)
	for (int i = 1; i < lLineLen; i++)
	{
		tensor[i - 1] = letterToIndex(szTarget[i]);
	}
	tensor[lLineLen-1] = EOS;

	return tensor;
}
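
As a quick sanity check of the three conversion helpers above, the snippet below (illustrative; "Chinese" and "Wang" are sample values, and read_categorys() is assumed to have been called already) builds the tensors for one training pair. For a 4-letter name, the line tensor has shape 4 x 1 x n_letters with one one-hot row per character, and the target tensor holds the indices of characters 2..4 followed by EOS:

	// Illustrative check of the tensor builders for a single sample pair
	torch::Tensor category_tensor = categoryToTensor("Chinese");	// shape {1, gn_categories}, one-hot category
	torch::Tensor line_tensor     = lineToTensor("Wang");			// shape {4, 1, gn_letters}, one-hot per character
	torch::Tensor target_tensor   = targetToTensor("Wang");		// shape {4}: indices of 'a', 'n', 'g', then EOS

	std::cout << category_tensor.sizes() << " " << line_tensor.sizes() << " " << target_tensor << std::endl;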



int read_categorys(wchar_t *pwszDirectoryPath, allCategorysList_t &List)
{
	EnumerateFileInDirectory(pwszDirectoryPath);

	enumerateFilesList_t::iterator iter;
	for (iter = enumerateFilesList.begin(); iter != enumerateFilesList.end(); iter++)
	{
		int sublen = iter->strCommonType.find(".");
		std::string strCategory = iter->strCommonType.substr(0, sublen);
		STUALLCATEGORYS categorys;
		categorys.strCommonType = strCategory;
		ListAppend(List, categorys);
	}

#if DEBUG_PRINT
	ListForEachElem(List);
#endif

	return 0;
}

int read_lines(wchar_t *pwszFilesPath)
{
	enumerateFilesList_t::iterator iter;
	for (iter = enumerateFilesList.begin(); iter != enumerateFilesList.end(); iter++)
	{
		std::wstring wstrFilesPath = DIRECTORYPATH;
		wstrFilesPath += L"\\";
		std::wstring wstr = Acsi2WideByteEx(iter->strCommonType);
		wstrFilesPath += wstr;
		STUENUMERATELINES linsInfo;
		enumerateLinesList_t eLinesList;
		ReadContentFromFiles(eLinesList, linsInfo, wstrFilesPath.c_str());
		int sublen = iter->strCommonType.find(".");
		std::string strCategory = iter->strCommonType.substr(0, sublen);

		gMapCategoryAllLines.insert(map_category_all_lines::value_type(strCategory, eLinesList));

	}

#if DEBUG_PRINT
	printf("verify category all lines map:\n");
	map_category_all_lines::iterator itMap;
	enumerateLinesList_t::iterator iterList;
	itMap = gMapCategoryAllLines.find("Dutch");
	if (itMap != gMapCategoryAllLines.end())
	{
		for (iterList = itMap->second.begin(); iterList != itMap->second.end(); iterList++)
		{
			std::cout << iterList->strCommonType << std::endl;
		}
	}
#endif

	return 0;
}

template<typename T>
std::string randomChoice(T l)
{
	int nlMaxSize = l.size();
	int nRandElem = rand() % nlMaxSize;

	std::string strElem = FindListElem(l, nRandElem);

#if DEBUG_PRINT
	cout << "nRandElem : " << nRandElem << " strElem: " << strElem << endl;
#endif

	return strElem;
}

std::string SelectElemFromMap(const map_category_all_lines& Map, std::string strCategory)
{
	map_category_all_lines::const_iterator	itMap;
	std::string strRandLine;

	itMap = Map.find(strCategory);		// use the map that was passed in instead of the global
	if (itMap != Map.end())
	{
		strRandLine = randomChoice(itMap->second);
	}

	return strRandLine;
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> randomTrainingPair()
{
	std::string	strRandCategory, strRandLine;
	strRandCategory = randomChoice(gAllCategorysList);
	strRandLine = SelectElemFromMap(gMapCategoryAllLines, strRandCategory);

	torch::Tensor category_tensor = categoryToTensor(strRandCategory.c_str());
	torch::Tensor line_tensor = lineToTensor(strRandLine.c_str());
	torch::Tensor target_tensor = targetToTensor(strRandLine.c_str());

	return { category_tensor , line_tensor,  target_tensor };
}
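
The data-reading code relies on small list/map helpers (allCategorysList_t, enumerateFilesList_t, enumerateLinesList_t, map_category_all_lines, STUALLCATEGORYS, STUENUMERATELINES, ListAppend, GetListElemCount, FindListElem, FindListElemSeq, EnumerateFileInDirectory, ReadContentFromFiles, Acsi2WideByteEx) whose definitions are omitted from the post. A rough STL-based sketch of what the containers are assumed to look like is given below; the author's actual implementations may differ:

// Assumed shapes of the helper containers (a sketch, not the author's actual code)
#include <list>
#include <map>
#include <string>

struct STUALLCATEGORYS   { std::string strCommonType; };	// one category / file name, e.g. "Chinese" or "Chinese.txt"
struct STUENUMERATELINES { std::string strCommonType; };	// one name line read from a file

typedef std::list<STUALLCATEGORYS>						allCategorysList_t;
typedef std::list<STUALLCATEGORYS>						enumerateFilesList_t;
typedef std::list<STUENUMERATELINES>					enumerateLinesList_t;
typedef std::map<std::string, enumerateLinesList_t>		map_category_all_lines;

// The small list helpers are assumed to be thin wrappers, e.g.:
inline int GetListElemCount(const allCategorysList_t& l) { return (int)l.size(); }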

3. The model

struct RNN : public torch::nn::Module
{
	RNN(int category_size, int input_size, int hidden_size, int output_size)
	{
		m_category_size = category_size;
		m_input_size = input_size;
		m_hidden_size = hidden_size;
		m_output_size = output_size;

		i2h = register_module("i2h", torch::nn::Linear(category_size + input_size + hidden_size, hidden_size));
		i2o = register_module("i2o", torch::nn::Linear(category_size + input_size + hidden_size, output_size));
		o2o = register_module("o2o", torch::nn::Linear(output_size + hidden_size, output_size));
		softmax = register_module("softmax", torch::nn::LogSoftmax(torch::nn::LogSoftmaxOptions(1)));
	}
	~RNN()
	{

	}

	std::tuple<torch::Tensor, torch::Tensor> forward(torch::Tensor category, torch::Tensor input, torch::Tensor hidden)
	{
		input_combined = torch::cat({ category, input, hidden }, 1);
		m_hidden = i2h(input_combined);
		output = i2o(input_combined);
		output_combined = torch::cat({ m_hidden, output }, 1);
		output = o2o(output_combined);
		/*output = softmax(output);*/	// intentionally skipped: CrossEntropyLoss in train() expects raw logits
		return { output, m_hidden };
	}

	torch::Tensor initHidden()
	{
		return torch::zeros({ 1, m_hidden_size });
	}

	int m_category_size = 0;
	int m_input_size = 0;
	int m_hidden_size = 0;
	int m_output_size = 0;

	torch::nn::Linear i2h{ nullptr }, i2o{ nullptr }, o2o{ nullptr };
	torch::nn::LogSoftmax softmax = nullptr;
	torch::Tensor input_combined, output, output_combined, m_hidden;
};
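
Each forward() call processes a single character: the category one-hot (1 x n_categories), the current character one-hot (1 x n_letters) and the hidden state (1 x hidden_size) are concatenated along dim 1, so i2h and i2o see a 1 x (n_categories + n_letters + hidden_size) input. A minimal, illustrative shape check follows (18 categories and a hidden size of 128 are assumptions, matching the tutorial dataset):

	// Illustrative single-step shape check
	RNN rnn(18, gn_letters, 128, gn_letters);
	torch::Tensor category = torch::zeros({ 1, 18 });			// one-hot nationality
	torch::Tensor input    = torch::zeros({ 1, gn_letters });	// one-hot current character
	torch::Tensor hidden   = rnn.initHidden();					// {1, 128}

	auto step   = rnn.forward(category, input, hidden);
	auto output = std::get<0>(step);							// {1, gn_letters}: raw logits for the next character
	auto next_h = std::get<1>(step);							// {1, 128}: hidden state fed into the next step
	std::cout << output.sizes() << " " << next_h.sizes() << std::endl;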

4. Training

std::tuple<torch::Tensor, torch::Tensor> train(RNN& rnn, torch::Device device,
												torch::Tensor category_tensor, torch::Tensor input_line_tensor, torch::Tensor target_line_tensor,
												torch::optim::Optimizer& optimizer, torch::nn::CrossEntropyLoss& criterion)
{
	rnn.to(device);
	rnn.train(true);

	// move the inputs to the same device as the model
	category_tensor = category_tensor.to(device);
	input_line_tensor = input_line_tensor.to(device);
	target_line_tensor = target_line_tensor.to(device);

	torch::Tensor hidden = rnn.initHidden().to(device);
	optimizer.zero_grad();

	// loss starts as a plain zero tensor; gradients flow through the per-step criterion terms added below
	torch::Tensor loss = torch::zeros({ 1 }, torch::kFloat64).to(device);
	int nline_tensor_size = 0;
	nline_tensor_size = input_line_tensor.sizes()[0];

	torch::Tensor output;

	for (int i = 0; i < nline_tensor_size; i++)
	{
		std::tuple<torch::Tensor, torch::Tensor> ret = rnn.forward(category_tensor, input_line_tensor[i], hidden);
		output = std::get<0>(ret);
		hidden = std::get<1>(ret);

		torch::Tensor temp = torch::zeros({ 1 }, torch::kLong).to(device);
		temp[0] = target_line_tensor[i].item().toLong();

		try
		{
			loss += criterion(output, temp);
		}
		catch (const c10::Error& e)
		{
			std::cout << "criterion an error has occured : " << e.msg() << std::endl;
		}
	}
	loss.backward();
	optimizer.step();

	return { output, loss};
}

5. Inference / generation

std::string generate_one(std::string category, std::string start_char, float temperature, RNN& rnn, torch::Device device)
{
	rnn.to(device);
	rnn.train(false);
	torch::NoGradGuard no_grad;		// no gradients are needed during generation

	torch::Tensor hidden = rnn.initHidden().to(device);

	torch::Tensor category_tensor = categoryToTensor(category.c_str()).to(device);
	torch::Tensor line_tensor = lineToTensor(start_char.c_str()).to(device);

	torch::Tensor output;

	std::string output_str = start_char;

	for (int i = 0; i < gmax_length; i++)
	{
		std::tuple<torch::Tensor, torch::Tensor> ret = rnn.forward(category_tensor, line_tensor[0], hidden);
		output = std::get<0>(ret);
		hidden = std::get<1>(ret);

		// Sample the next character: exp(logit / temperature) gives the weights, multinomial() normalizes them
		torch::Tensor output_dist = output.data().view(-1).div(temperature).exp();
		torch::Tensor top_i = torch::multinomial(output_dist, 1)[0];

		//Stop at EOS, or add to output_str
		if (top_i.item().toInt() == EOS)
			break;
		else
		{
			char sz[2] = { gszAll_Letters[top_i.item().toInt()], '\0' };	// null-terminate: lineToTensor expects a C string
			output_str += sz[0];
			line_tensor = lineToTensor(sz).to(device);
		}
	}

	return output_str;
}
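
Because multinomial() normalizes its weights, exp(logit / temperature) above is exactly a softmax with temperature: lower temperatures sample more conservatively (closer to the most likely character), higher temperatures produce more varied names. For example (illustrative calls, same signature as above):

	// Lower temperature -> more conservative names, higher temperature -> more varied names
	std::string conservative = generate_one("Chinese", "W", 0.2f, rnn, device);
	std::string balanced     = generate_one("Chinese", "W", 0.5f, rnn, device);
	std::string varied       = generate_one("Chinese", "W", 1.0f, rnn, device);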

void generate(std::string strCategory, std::string start_chars, RNN& rnn, torch::Device device)
{
	int nStart_Chars_len = start_chars.length();

	printf("\n%s : \n\n", strCategory.c_str());
	for (int i = 0; i < nStart_Chars_len; i ++)
	{
		std::string start_char = start_chars.substr(i, 1);
		printf( "%s\n", generate_one(strCategory, start_char, 0.5, rnn, device));
		//printf("start_char : %s\n", start_char.c_str());
	}
}

Summary

The next post will cover the Shakespeare text from the official tutorials: training on it and generating passages in the same style.
