struct DCGANGeneratorImpl : nn::Module {DCGANGeneratorImpl(int kNoiseSize):conv1(nn::ConvTranspose2dOptions(kNoiseSize,256,4).bias(false)),batch_norm1(256),conv2(nn::ConvTranspose2dOptions(256,128,3).stride(2).padding(1).bias(false)),batch_norm2(128),conv3(nn::ConvTranspose2dOptions(128,64,4).stride(2).padding(1).bias(false)),batch_norm3(64),conv4(nn::ConvTranspose2dOptions(64,1,4).stride(2).padding(1).bias(false)){// register_module() is needed if we want to use the parameters() method later onregister_module("conv1", conv1);register_module("conv2", conv2);register_module("conv3", conv3);register_module("conv4", conv4);register_module("batch_norm1", batch_norm1);register_module("batch_norm2", batch_norm2);register_module("batch_norm3", batch_norm3);}
torch::Tensor forward(torch::Tensor x){
x = torch::relu(batch_norm1(conv1(x)));
x = torch::relu(batch_norm2(conv2(x)));
x = torch::relu(batch_norm3(conv3(x)));
x = torch::tanh(conv4(x));return x;}
nn::ConvTranspose2d conv1, conv2, conv3, conv4;
nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;};TORCH_MODULE(DCGANGenerator);
DCGANGenerator generator(kNoiseSize);
PyTorch C++ 官方教程摘要(1) Using the PyTorch C++ Frontend
文章目录0. 前言1. 为什么要用C++2. DCGAN PyTorch C++ 示例2.1. 使用基本流程2.2. 网络结构定义0. 前言PyTorch官方教程中有一些C++相关的内容。今天要学习的主要是 Using The Pytorch C++ Frontend本文主要内容包括:为什么要用C++以DCGAN为例实现功能1. 为什么要用C++其实就是相比Python,C++的优势。C++前端的目标不是替代Python前端,而是补充。Low Latency Syste