libtorch-char-rnn-shakespeare
libtorch-UNET Carvana车辆轮廓分割
libtorch-VGG cifar10分类
libtorch-FCN CamVid语义分割
libtorch-RDN DIV2K超分重建
libtorch-char-rnn-classification libtorch官网系列教程
libtorch-char-rnn-generation libtorch官网系列教程
libtorch-char-rnn-shakespeare libtorch官网系列教程
libtorch-minst libtorch官网系列教程
————————————————
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/barcaFC/article/details/140116269
前言
延续pytorch-rnn官网教程,本文为libtorch教程上的文章内容生成,使用的是莎士比亚文章内容进行学习和生成,可以自由修改成看图写话或者小短文的生成。
今年我国在内容生成上的专利申请数全球第一。
内容生成可以生成数据、文字、图片、视频、音乐等等。然而,再重复一次,解决问题不是要全球领先,而是根据实际问题选用合适的模型进行改动。
libtorch C++ libtorch-char-rnn-shakespeare
从数据集准备、训练、推理 全部由libtorch c++完成
运行环境:
操作系统:windows 64位
opencv3.4.1
libtorch 任意版本 64位
visual studio 2017 64位编译
关联参考:
https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutorial.html
https://github.com/spro/practical-pytorch/tree/master/char-rnn-generation
RNN_KERNEL_VERSION 变为2.0版本,使用encoder-decoder结构、WordEmbedding、rnn-GRU
因为生成对话需要较长距离的上下文内容
系列中data.hpp文件跟随官网改变成helpers.hpp
read_file函数读取原始文件内容由前两版ReadContentFromFiles函数修改而来
————————————————
一、rnn-shakespeare是什么?
和基于字符级生成不同,本文使用WordEmbedding,对词进行了编码使得能够词生成,而且结构上具有了后续seq2seq的核心模型encoder-decoder。而模型演示的encoder-decoder也很简单,但实现的功能却是很强大的。
无论是RDN进行图片重建或者是RNN内容生成,均具有前后关联(跳跃链接、隐藏层)以及encoder-decoder结构。
二、使用步骤
1.设置参数
int main(int argc, char **argv)
{
    // Pick the compute device: CUDA when present, otherwise CPU.
    if (torch::cuda::is_available())
    {
        printf("torch::cuda::is_available\n");
        device_type = torch::kCUDA;
    }
    else
    {
        printf("cpu is_available\n");
        device_type = torch::kCPU;
    }
    torch::Device device(device_type);
#if _TRAIN_STAGE_ // training stage
    srand((unsigned)time(NULL));
    read_file(L"..\\data\\shakespeare.txt");
    RNN decoder(n_characters, hidden_size, n_characters, n_layers);
    auto decoder_optimizer = torch::optim::Adam(decoder.parameters(), torch::optim::AdamOptions(learning_rate));
    auto criterion = torch::nn::CrossEntropyLoss();
    printf("Training for %d epochs...", n_epochs);
    for (int epoch = 0; epoch < n_epochs; epoch++)
    {
        // Draw a random (input, target) chunk; target is input shifted by one.
        std::tuple<torch::Tensor, torch::Tensor> retRandomSet = random_training_set(chunk_len);
        torch::Tensor inp = std::get<0>(retRandomSet);
        torch::Tensor target = std::get<1>(retRandomSet);
        torch::Tensor loss = train(decoder, device, inp, target, decoder_optimizer, criterion);
        // Periodically print epoch number, progress percentage and mean loss.
        if (0 == (epoch % print_every))
        {
            // BUGFIX: multiply by 100 before dividing -- the original
            // "epoch / n_epochs" was pure integer division and always printed 0%.
            printf("(%d %d%%) loss: %.4f\n", epoch, epoch * 100 / n_epochs, (loss.item().toFloat() / chunk_len));
        }
    }
    printf("Finish training!\n");
    // Persist the trained weights.
    torch::serialize::OutputArchive archive;
    decoder.save(archive);
    archive.save_to("..\\retModel\\char-rnn-shakespeare-c++.pt");
    printf("Save the training result to ..\\char-rnn-shakespeare-c++.pt.\n");
#else // inference stage: load the checkpoint and sample 200 tokens
    RNN decoder(n_characters, hidden_size, n_characters, n_layers);
    torch::serialize::InputArchive archive;
    archive.load_from("..\\retModel\\char-rnn-shakespeare-c++.pt");
    decoder.load(archive);
    std::string prime_str = "Anger";
    cout << endl << generate(prime_str, decoder, device, 200, 0.8) << endl;
#endif // BUGFIX: removed stray ';' after #endif
    return 0;
}
2.网络结构
struct RNN : public torch::nn::Module
{
    // Embedding -> multi-layer GRU -> Linear read-out, mirroring the
    // encoder/decoder layout of the practical-pytorch char-rnn tutorial.
    // input_size/output_size are the vocabulary size (n_characters).
    RNN(int input_size, int hidden_size, int output_size, int n_layers)
    {
        m_input_size = input_size;
        m_hidden_size = hidden_size;
        m_output_size = output_size;
        m_n_layers = n_layers;
        // Submodules keep the same registered names so serialized
        // checkpoints remain loadable.
        encoder = register_module("encoder", torch::nn::Embedding(input_size, hidden_size));
        gru = register_module("gru", torch::nn::GRU(torch::nn::GRUOptions(hidden_size, hidden_size).num_layers(n_layers)));
        decoder = register_module("decoder", torch::nn::Linear(hidden_size, output_size));
    }
    ~RNN()
    {
    }
    // Single-step forward: embed one token, advance the GRU by one step,
    // then project the GRU output back onto the vocabulary.
    // Returns { logits of shape [1, output_size], updated hidden state }.
    std::tuple<torch::Tensor, torch::Tensor> forward(torch::Tensor input, torch::Tensor hidden)
    {
        m_input = encoder(input.view({ 1, -1 }));
        std::tie(m_output, m_hidden) = gru(m_input.view({ 1, 1, -1 }), hidden);
        m_output = decoder(m_output.view({ 1, -1 }));
        return std::make_tuple(m_output, m_hidden);
    }
    // Fresh all-zero hidden state of shape [n_layers, batch = 1, hidden_size].
    torch::Tensor initHidden()
    {
        return torch::zeros({ m_n_layers, 1, m_hidden_size });
    }
    int m_n_layers = 0;
    int m_input_size = 0;
    int m_hidden_size = 0;
    int m_output_size = 0;
    torch::nn::Embedding encoder{ nullptr };
    torch::nn::GRU gru{ nullptr };
    torch::nn::Linear decoder{ nullptr };
    torch::Tensor m_input, m_output, m_hidden;
};
3.训练
/* Author: barcaFC
 * Function: random_training_set
 * Purpose : pick a random chunk of chunk_len + 1 consecutive characters
 *           from the corpus; inp = chars [0, chunk_len) of the chunk,
 *           target = chars [1, chunk_len + 1) (inp shifted by one).
 * Input   : chunk_len -- number of training characters per sample
 * Output  : { inp, target } tensors ready for the word embedding
 */
std::tuple<torch::Tensor, torch::Tensor> random_training_set(int chunk_len)
{
    // BUGFIX: keep the start far enough from the end of the corpus so a
    // full chunk_len + 1 characters can always be extracted; the original
    // "rand() % gfilelen" could start near EOF and yield a short chunk.
    int max_start = gfilelen - chunk_len - 1;
    int start_index = (max_start > 0) ? (rand() % max_start) : 0;
    // BUGFIX: std::string::substr takes (pos, COUNT), not (pos, end_index).
    // The original passed an end index here and then re-applied
    // start_index inside the already-offset chunk, producing wrong (and
    // potentially out-of-range) slices.
    std::string chunk = gfile.substr(start_index, chunk_len + 1);
    torch::Tensor inp = char_tensor(chunk.substr(0, chunk_len));
    torch::Tensor target = char_tensor(chunk.substr(1, chunk_len));
    return { inp, target };
}
// Run one training step over a single chunk: feed the characters one at a
// time, accumulate the cross-entropy loss against the shifted target,
// then backprop and step the optimizer. Returns the summed (unaveraged) loss.
torch::Tensor train(RNN& decoder, torch::Device device,
	torch::Tensor inp, torch::Tensor target,
	torch::optim::Optimizer& decoder_optimizer, torch::nn::CrossEntropyLoss& criterion)
{
    decoder.to(device);
    decoder.train(true);
    // BUGFIX: the inputs and the initial hidden state must live on the
    // same device as the model; the original left them on the CPU, which
    // crashes as soon as device is kCUDA.
    inp = inp.to(device);
    target = target.to(device);
    torch::Tensor hidden = decoder.initHidden().to(device);
    decoder_optimizer.zero_grad();
    torch::Tensor loss = torch::zeros({ 1 }, torch::kFloat64).to(device);
    torch::Tensor output;
    for (int c = 0; c < chunk_len; c++)
    {
        std::tuple<torch::Tensor, torch::Tensor> retTrain = decoder.forward(inp[c], hidden);
        output = std::get<0>(retTrain);
        hidden = std::get<1>(retTrain);
        try
        {
            // CrossEntropyLoss expects a 1-D class-index target.
            torch::Tensor temp = torch::zeros({ 1 }, torch::kLong).to(device);
            temp[0] = target[c].item().toLong();
            loss += criterion(output, temp);
        }
        catch (const c10::Error& e)
        {
            // Best-effort: report and bail out with the partial loss
            // (no backward pass), matching the original behavior.
            std::cout << "criterion an error has occured : " << e.msg() << std::endl;
            return loss;
        }
    }
    loss.backward();
    decoder_optimizer.step();
    return loss;
}
4.生成
// Sample predict_len characters from the model, seeded with prime_str.
// temperature < 1 sharpens the distribution, > 1 flattens it.
// Returns prime_str with the generated continuation appended.
std::string generate(std::string prime_str, RNN& decoder, torch::Device device, int predict_len, float temperature)
{
    torch::NoGradGuard no_grad; // inference only -- no autograd bookkeeping
    decoder.to(device);
    decoder.train(false);
    // BUGFIX: hidden state and inputs must be on the model's device,
    // otherwise generation crashes when device is kCUDA.
    torch::Tensor hidden = decoder.initHidden().to(device);
    std::string predicted = prime_str;
    // Guard against an empty priming string (prime_input[-1] below
    // would otherwise index out of range).
    if (prime_str.empty())
        return predicted;
    torch::Tensor prime_input = char_tensor(prime_str).to(device);
    // Use the priming string to "build up" the hidden state.
    for (int p = 0; p < prime_str.length(); p++)
    {
        std::tuple<torch::Tensor, torch::Tensor> retHidden = decoder.forward(prime_input[p], hidden);
        hidden = std::get<1>(retHidden);
    }
    torch::Tensor inp = prime_input[prime_str.length() - 1];
    torch::Tensor output;
    for (int c = 0; c < predict_len; c++)
    {
        std::tuple<torch::Tensor, torch::Tensor> retDecoder = decoder.forward(inp, hidden);
        output = std::get<0>(retDecoder);
        hidden = std::get<1>(retDecoder);
        // Sample from the network output as a multinomial distribution.
        torch::Tensor output_dist = output.data().view(-1).div(temperature).exp();
        torch::Tensor top_i = torch::multinomial(output_dist, 1)[0];
        // Append the predicted character and feed it back as next input.
        char predicted_char = all_characters.at(top_i.item().toInt());
        predicted += predicted_char;
        std::string strINP;
        strINP = predicted_char;
        inp = char_tensor(strINP).to(device);
    }
    return predicted;
}
5.辅助函数
// Character vocabulary (mirrors the printable set used by the Python
// tutorial). BUGFIX: the backslash must be written as "\\" -- the original
// "[\]" was an invalid escape sequence that silently dropped '\' from the
// vocabulary (compilers treat "\]" as plain "]").
std::string all_characters = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+, -./:;<=>?@[\\]^_`{|}~";
int n_characters = (int)all_characters.length(); // vocabulary size
int gfilelen = 0;                                // corpus length in characters
std::string gfile;                               // entire training corpus
// Load the training corpus into the global gfile buffer and record its
// length in gfilelen. Returns FALSE if the file cannot be opened.
// NOTE: getline() strips line terminators, so newlines are not part of
// the corpus the model sees (matches the original behavior).
BOOL read_file(const wchar_t * pwszFilePath)
{
    ifstream fin;
    std::string str;
    fin.open(pwszFilePath, ios::in);
    // BUGFIX: check is_open() BEFORE touching the stream; the original
    // seeked and called tellg() on a possibly-unopened stream first.
    if (!fin.is_open())
    {
        cout << "无法找到这个文件!" << endl;
        return FALSE;
    }
    while (getline(fin, str))
    {
        gfile += str;
    }
    fin.close();
    // BUGFIX: derive the length from the concatenated buffer instead of
    // tellg(). getline() strips the newline characters, so the on-disk
    // size overstates gfile.length() and made the random offsets in
    // random_training_set() index past the end of gfile.
    gfilelen = (int)gfile.length();
    return TRUE;
}
// Encode a string as a 1-D Long tensor of vocabulary indices into
// all_characters, suitable as input to the embedding layer.
torch::Tensor char_tensor(std::string str)
{
    LONGLONG len = str.length();
    torch::Tensor tensor = torch::zeros({ len }, torch::kLong);
    for (int c = 0; c < len; c++)
    {
        // BUGFIX: find() returns npos for out-of-vocabulary characters
        // (e.g. '\t'); casting npos to LONG produced a garbage index that
        // crashes the embedding lookup. Map unknown characters to index 0.
        size_t pos = all_characters.find(str[c]);
        tensor[c] = (pos == std::string::npos) ? (LONG)0 : (LONG)pos;
    }
    return tensor;
}
总结
内容生成(AIGC)是目前炙手可热的AI分支,内容生成不仅仅是能够生成训练内容,也是一种迭代逼近某种解析解(具有无限解空间的最优解)的方法。将内容生成和其他提取类模型相结合,能够辅助完成非常多之前需要大量人力物力的工作。