libtorch-char-rnn-generation
libtorch-UNET Carvana车辆轮廓分割
libtorch-VGG cifar10分类
libtorch-FCN CamVid语义分割
libtorch-RDN DIV2K超分重建
libtorch-char-rnn-classification libtorch官网系列教程
libtorch-char-rnn-generation libtorch官网系列教程
libtorch-char-rnn-shakespeare libtorch官网系列教程
libtorch-minst libtorch官网系列教程
————————————————
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/barcaFC/article/details/140090906
前言
延续pytorch-rnn官网教程,字符级生成,无论是后来的transformer或者大模型chat-gpt都是由最小的rnn模型构建和演变而成,其中最为核心的梯度下降更是贯穿所有现代深度学习的框架。
libtorch C++ libtorch-char-rnn-generation
从数据集准备、训练、推理 全部由libtorch c++完成
运行环境:
操作系统:windows 64位
opencv3.4.1
libtorch 任意版本 64位
visual studio 2017 64位编译
相关数据集 names
参考教程(PyTorch官方char-rnn-generation教程,非论文) https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html
————————————————
一、rnn-char-generation是什么?
与分类任务(由人名推断国籍)相反,字符级生成是在给定条件(例如国籍)和起始字母的情况下,逐字符地生成对应内容,例如本文中根据国籍生成人名。同时这也是生成文章和对话的基础。
二、使用步骤
1.参数设置
代码如下(示例):
// Entry point: trains the char-RNN when _TRAIN_STAGE_ is set, otherwise
// loads the saved model and generates names for a hard-coded category.
int main(int argc, char **argv)
{
    // Select the compute device: CUDA when available, otherwise CPU.
    if (torch::cuda::is_available())
    {
        printf("torch::cuda::is_available\n");
        device_type = torch::kCUDA;
    }
    else
    {
        printf("cpu is_available\n");
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);
#if _TRAIN_STAGE_ // training stage
    // Preparing for Training
    srand((unsigned)time(NULL));
    // Data preprocessing: enumerate category files and load all name lines.
    read_categorys(DIRECTORYPATH, gAllCategorysList);
    read_lines(DIRECTORYPATH);
    gn_categories = GetListElemCount(gAllCategorysList);
    RNN rnn(gn_categories, gn_letters, hidden_size, gn_letters);
    auto optimizer = torch::optim::Adam(rnn.parameters(), torch::optim::AdamOptions(learning_rate));
    auto criterion = torch::nn::CrossEntropyLoss();
    for (int epoch = 0; epoch < n_epochs; epoch++)
    {
        // Draw one random (category, name) sample as tensors.
        std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> ret = randomTrainingPair();
        torch::Tensor category_tensor = std::get<0>(ret);
        torch::Tensor line_tensor = std::get<1>(ret);
        torch::Tensor target_tensor = std::get<2>(ret);
        int nline_tensor_size = 0;
        nline_tensor_size = line_tensor.sizes()[0];
        std::tuple<torch::Tensor, torch::Tensor> retTran = train(rnn, device, category_tensor, line_tensor, target_tensor, optimizer, criterion);
        torch::Tensor output = std::get<0>(retTran);
        torch::Tensor loss = std::get<1>(retTran);
        loss_avg += (loss.item().toFloat() / nline_tensor_size);
        // Print epoch number, progress percentage and per-character loss.
        if (0 == (epoch % print_every))
        {
            // BUGFIX: "epoch / n_epochs * 100" is integer division and always
            // printed 0% until the very end; multiply before dividing.
            printf("(%d %d%%) loss: %.4f\n", epoch, epoch * 100 / n_epochs, (loss.item().toFloat() / nline_tensor_size));
            loss_avg = 0;
        }
    }
    printf("Finish training!\n");
    torch::serialize::OutputArchive archive;
    rnn.save(archive);
    archive.save_to("..\\retModel\\char-rnn-generation-c++.pt");
    // BUGFIX: the message now reports the path actually passed to save_to().
    printf("Save the training result to ..\\retModel\\char-rnn-generation-c++.pt.\n");
#else // inference stage
    std::string strCategory = "Chinese";
    std::string strStart_Chars = "CHI";
    // Data preprocessing: rebuild the category/letter bookkeeping so the
    // model shape matches the one that was trained.
    read_categorys(DIRECTORYPATH, gAllCategorysList);
    read_lines(DIRECTORYPATH);
    gn_categories = GetListElemCount(gAllCategorysList);
    RNN rnn(gn_categories, gn_letters, hidden_size, gn_letters);
    torch::serialize::InputArchive archive;
    archive.load_from("..\\retModel\\char-rnn-generation-c++.pt");
    rnn.load(archive);
    generate(strCategory, strStart_Chars, rnn, device);
#endif
    return 0;
}
2.数据处理
代码如下(示例):
//Preparing the Data
// Alphabet recognized by the model; a character's one-hot index is its
// position in this table.
char gszAll_Letters[] = { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
                          'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f',
                          'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
                          'w', 'x', 'y', 'z', ' ', '.', ',', ';', '\'', '-' };
// Vocabulary size: every letter above plus one extra slot for the EOS marker.
int gn_letters = sizeof(gszAll_Letters) + 1;
// EOS uses the last slot of the vocabulary.
int EOS = gn_letters - 1;

// Map a character to its index in gszAll_Letters; returns -1 when the
// character is not part of the alphabet.
int letterToIndex(char szLetter)
{
    const int nCount = (int)sizeof(gszAll_Letters);
    for (int pos = 0; pos < nCount; ++pos)
    {
        if (szLetter == gszAll_Letters[pos])
            return pos;
    }
    return -1;
}
/*作者:barcaFC */
/*日期:2022-11-1 */
/*函数:lineToTensor */
/*功能:将line以one-hot的方式转换进tensor */
/*<line_length x 1 x n_letters> */
/*输入:line */
/*输出:line对应的one-hot tensor */
// Convert a name string to a one-hot tensor of shape
// <line_length x 1 x n_letters>.
torch::Tensor lineToTensor(const char* szLine)
{
    LONGLONG lLineLen = strlen(szLine);
    torch::Tensor tensor = torch::zeros({ lLineLen, 1, gn_letters });
    for (int i = 0; i < lLineLen; i++)
    {
        int nIdx = letterToIndex(szLine[i]);
        // BUGFIX: letterToIndex() returns -1 for characters outside the
        // alphabet; a -1 index counts from the end in libtorch, so the old
        // code silently flagged the EOS slot. Skip unknown characters
        // instead (their row stays all-zero).
        if (nIdx >= 0)
            tensor[i][0][nIdx] = 1;
    }
#if DEBUG_PRINT
    cout << "one-hot line tensor size:\n" << tensor.data().sizes() << endl;
    cout << "one-hot line tensor:\n" << tensor.data() << endl;
#endif
    return tensor;
}
// Build a one-hot <1 x n_categories> tensor for a category name.
torch::Tensor categoryToTensor(const char* szCategory)
{
    std::string strCategory = szCategory;
    int li = FindListElemSeq(gAllCategorysList, strCategory);
    torch::Tensor tensor = torch::zeros({ 1, gn_categories });
    // BUGFIX: guard the index -- presumably FindListElemSeq() returns a
    // negative value for an unknown category (TODO confirm); indexing with
    // it would flag the wrong slot or fail. Unknown categories now yield an
    // all-zero tensor.
    if (li >= 0 && li < gn_categories)
        tensor[0][li] = 1;
    return tensor;
}
// Build the target sequence for teacher forcing: for input character i the
// target is character i+1, and the final target is the EOS marker.
torch::Tensor targetToTensor(const char* szTarget)
{
    LONGLONG lLineLen = strlen(szTarget);
    // BUGFIX: an empty string made "tensor[lLineLen - 1]" index slot -1.
    if (0 == lLineLen)
        return torch::zeros({ 0 }, torch::kLong);
    torch::Tensor tensor = torch::zeros({ lLineLen }, torch::kLong);
    for (int i = 1; i < lLineLen; i++)
    {
        // NOTE(review): letterToIndex() can return -1 for characters outside
        // the alphabet; such labels would break CrossEntropyLoss -- assumes
        // the data set only contains alphabet characters. TODO confirm.
        tensor[i - 1] = letterToIndex(szTarget[i]);
    }
    tensor[lLineLen-1] = EOS;
    return tensor;
}
int read_categorys(wchar_t *pwszDirectoryPath, allCategorysList_t &List)
{
EnumerateFileInDirectory(pwszDirectoryPath);
enumerateFilesList_t::iterator iter;
for (iter = enumerateFilesList.begin(); iter != enumerateFilesList.end(); iter++)
{
int sublen = iter->strCommonType.find(".");
std::string strCategory = iter->strCommonType.substr(0, sublen);
STUALLCATEGORYS categorys;
categorys.strCommonType = strCategory;
ListAppend(List, categorys);
}
#if DEBUG_PRINT
ListForEachElem(List);
#endif
return 0;
}
// Read every enumerated data file and map its category (file name without
// extension) to the list of name lines it contains (gMapCategoryAllLines).
// NOTE(review): the pwszFilesPath parameter is unused -- the path is built
// from the DIRECTORYPATH constant instead; presumably they are the same.
int read_lines(wchar_t *pwszFilesPath)
{
    enumerateFilesList_t::iterator iter;
    for (iter = enumerateFilesList.begin(); iter != enumerateFilesList.end(); iter++)
    {
        // Compose the full wide-char path: <dir>\<file name>.
        std::wstring wstrFilesPath = DIRECTORYPATH;
        wstrFilesPath += L"\\";
        std::wstring wstr = Acsi2WideByteEx(iter->strCommonType);
        wstrFilesPath += wstr;
        STUENUMERATELINES linsInfo;
        enumerateLinesList_t eLinesList;
        ReadContentFromFiles(eLinesList, linsInfo, wstrFilesPath.c_str());
        // Category name = file name without extension.
        int sublen = iter->strCommonType.find(".");
        std::string strCategory = iter->strCommonType.substr(0, sublen);
        gMapCategoryAllLines.insert(map_category_all_lines::value_type(strCategory, eLinesList));
    }
#if DEBUG_PRINT
    printf("verify category all lines map:\n");
    map_category_all_lines::iterator itMap;
    enumerateLinesList_t::iterator iterList;
    itMap = gMapCategoryAllLines.find("Dutch");
    if (itMap != gMapCategoryAllLines.end())
    {
        for (iterList = itMap->second.begin(); iterList != itMap->second.end(); iterList++)
        {
            std::cout << iterList->strCommonType << std::endl;
        }
    }
#endif
    // BUGFIX: the function is declared int but had no return statement,
    // which is undefined behavior in C++.
    return 0;
}
// Return a uniformly random element from the container.
// PERF FIX: take the container by reference -- the old by-value parameter
// copied the whole list on every call.
template<typename T>
std::string randomChoice(T& l)
{
    int nlMaxSize = (int)l.size();
    // Guard: avoid the "rand() % 0" crash on an empty container.
    if (0 == nlMaxSize)
        return std::string();
    int nRandElem = rand() % nlMaxSize;
    std::string strElem = FindListElem(l, nRandElem);
#if DEBUG_PRINT
    cout << "nRandElem : " << nRandElem << " strElem: " << strElem << endl;
#endif
    return strElem;
}
std::string SelectElemFromMap(map_category_all_lines Map, std::string strCategory)
{
bool bRes = false;
map_category_all_lines::iterator itMap;
std::string strRandLine;
itMap = gMapCategoryAllLines.find(strCategory);
if (itMap != gMapCategoryAllLines.end())
{
strRandLine = randomChoice(itMap->second);
}
return strRandLine;
}
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> randomTrainingPair()
{
std::string strRandCategory, strRandLine;
strRandCategory = randomChoice(gAllCategorysList);
strRandLine = SelectElemFromMap(gMapCategoryAllLines, strRandCategory);
torch::Tensor category_tensor = categoryToTensor(strRandCategory.c_str());
torch::Tensor line_tensor = lineToTensor(strRandLine.c_str());
torch::Tensor target_tensor = targetToTensor(strRandLine.c_str());
return { category_tensor , line_tensor, target_tensor };
}
3.模型
// Conditional character-level RNN (after the PyTorch char-rnn-generation
// tutorial): each step consumes the category one-hot, the current letter
// one-hot and the previous hidden state, and emits logits over the next
// letter plus the new hidden state.
struct RNN : public torch::nn::Module
{
// category_size: number of categories; input_size/output_size: vocabulary
// size (letters + EOS); hidden_size: width of the recurrent state.
RNN(int category_size, int input_size, int hidden_size, int output_size)
{
m_category_size = category_size;
m_input_size = input_size;
m_hidden_size = hidden_size;
m_output_size = output_size;
// i2h: combined input -> next hidden; i2o: combined input -> pre-output.
i2h = register_module("i2h", torch::nn::Linear(category_size + input_size + hidden_size, hidden_size));
i2o = register_module("i2o", torch::nn::Linear(category_size + input_size + hidden_size, output_size));
// o2o: (hidden, pre-output) -> final output logits.
o2o = register_module("o2o", torch::nn::Linear(output_size + hidden_size, output_size));
softmax = register_module("softmax", torch::nn::LogSoftmax(torch::nn::LogSoftmaxOptions(1)));
}
~RNN()
{
}
// One recurrence step. Returns (output logits, new hidden state).
// NOTE: intermediates are stored in member tensors (see below), so a single
// RNN instance is not safe to call from multiple threads concurrently.
std::tuple<torch::Tensor, torch::Tensor> forward(torch::Tensor category, torch::Tensor input, torch::Tensor hidden)
{
input_combined = torch::cat({ category, input, hidden }, 1);
m_hidden = i2h(input_combined);
output = i2o(input_combined);
output_combined = torch::cat({ m_hidden, output }, 1);
output = o2o(output_combined);
// LogSoftmax is registered but skipped here -- presumably because training
// uses CrossEntropyLoss, which applies log-softmax internally; TODO confirm
// the inference path accounts for raw logits.
/*output = softmax(output);*/
return { output, m_hidden };
}
// Fresh all-zero hidden state of shape <1 x hidden_size>.
torch::Tensor initHidden()
{
return torch::zeros({ 1, m_hidden_size });
}
int m_category_size = 0;
int m_input_size = 0;
int m_hidden_size = 0;
int m_output_size = 0;
torch::nn::Linear i2h{ nullptr }, i2o{ nullptr }, o2o{ nullptr };
torch::nn::LogSoftmax softmax = nullptr;
// Scratch tensors reused by forward(); kept as members in the original design.
torch::Tensor input_combined, output, output_combined, m_hidden;
};
4.训练
// Run one training step over a full name (one optimizer update per name):
// feeds the name character by character, accumulates the per-character
// CrossEntropyLoss, then backpropagates once.
// Returns (last step's output logits, accumulated loss).
std::tuple<torch::Tensor, torch::Tensor> train(RNN& rnn, torch::Device device,
torch::Tensor category_tensor, torch::Tensor input_line_tensor, torch::Tensor target_line_tensor,
torch::optim::Optimizer& optimizer, torch::nn::CrossEntropyLoss& criterion)
{
    rnn.to(device);
    rnn.train(true);
    // BUGFIX: move the inputs and the hidden state to the model's device --
    // the old code left them on the CPU, which fails once CUDA is selected.
    category_tensor = category_tensor.to(device);
    input_line_tensor = input_line_tensor.to(device);
    target_line_tensor = target_line_tensor.to(device);
    torch::Tensor hidden = rnn.initHidden().to(device);
    optimizer.zero_grad();
    torch::Tensor loss = torch::zeros({ 1 }, torch::kFloat64).to(device);
    // (Removed a leftover debug cout that printed loss.grad_fn() every call.)
    int nline_tensor_size = input_line_tensor.sizes()[0];
    torch::Tensor output;
    for (int i = 0; i < nline_tensor_size; i++)
    {
        std::tuple<torch::Tensor, torch::Tensor> ret = rnn.forward(category_tensor, input_line_tensor[i], hidden);
        output = std::get<0>(ret);
        hidden = std::get<1>(ret);
        // CrossEntropyLoss expects a 1-D class-index tensor (batch of 1).
        torch::Tensor temp = torch::zeros({ 1 }, torch::kLong).to(device);
        temp[0] = target_line_tensor[i].item().toLong();
        try
        {
            loss += criterion(output, temp);
        }
        catch (const c10::Error& e)
        {
            std::cout << "criterion an error has occured : " << e.msg() << std::endl;
        }
    }
    loss.backward();
    optimizer.step();
    return { output, loss};
}
5.推理生成
// Sample one name for the given category, seeded with start_char.
// temperature < 1 sharpens the sampling distribution, > 1 flattens it.
// Stops at the EOS symbol or after gmax_length characters.
std::string generate_one(std::string category, std::string start_char, float temperature, RNN& rnn, torch::Device device)
{
    rnn.to(device);
    rnn.train(false);
    // BUGFIX: keep all tensors on the model's device (the old code left them
    // on the CPU, which fails under CUDA).
    torch::Tensor hidden = rnn.initHidden().to(device);
    torch::Tensor category_tensor = categoryToTensor(category.c_str()).to(device);
    torch::Tensor line_tensor = lineToTensor(start_char.c_str()).to(device);
    torch::Tensor output;
    std::string output_str = start_char;
    for (int i = 0; i < gmax_length; i++)
    {
        std::tuple<torch::Tensor, torch::Tensor> ret = rnn.forward(category_tensor, line_tensor[0], hidden);
        output = std::get<0>(ret);
        hidden = std::get<1>(ret);
        // Sample the next letter: exp(logits / T) is proportional to the
        // temperature-scaled softmax, which multinomial() normalizes.
        torch::Tensor output_dist = output.data().view(-1).div(temperature).exp();
        torch::Tensor top_i = torch::multinomial(output_dist, 1)[0];
        // Stop at EOS, otherwise append the letter and feed it back in.
        if (top_i.item().toInt() == EOS)
            break;
        else
        {
            char sz = gszAll_Letters[top_i.item().toInt()];
            output_str += sz;
            // BUGFIX: lineToTensor() calls strlen(), so it needs a
            // NUL-terminated string; the old "&sz" pointed at a single char
            // and read past it (undefined behavior).
            char szNext[2] = { sz, '\0' };
            line_tensor = lineToTensor(szNext).to(device);
        }
    }
    return output_str;
}
// Generate and print one name per starting character in start_chars for the
// given category.
void generate(std::string strCategory, std::string start_chars, RNN& rnn, torch::Device device)
{
    int nStart_Chars_len = start_chars.length();
    printf("\n%s : \n\n", strCategory.c_str());
    for (int i = 0; i < nStart_Chars_len; i ++)
    {
        std::string start_char = start_chars.substr(i, 1);
        // BUGFIX: printf's %s requires a C string; passing a std::string
        // object is undefined behavior. Use .c_str().
        printf( "%s\n", generate_one(strCategory, start_char, 0.5, rnn, device).c_str());
        //printf("start_char : %s\n", start_char.c_str());
    }
}
总结
下一篇将是官网教程中的莎士比亚文章训练和生成同类型段落。