libtorch-UNET Carvana车辆轮廓分割
libtorch-UNET Carvana车辆轮廓分割
libtorch-VGG cifar10分类
libtorch-FCN CamVid语义分割
libtorch-RDN DIV2K超分重建
libtorch-char-rnn-classification libtorch官网系列教程
libtorch-char-rnn-generation libtorch官网系列教程
libtorch-char-rnn-shakespeare libtorch官网系列教程
libtorch-mnist libtorch官网系列教程
————————————————
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
原文链接:https://blog.csdn.net/barcaFC/article/details/139865461
前言
libtorch C++ pytorch unet Carvana车辆轮廓分割
数据集准备、训练与推理全部由 libtorch C++ 完成
运行环境:
操作系统:windows 64位
OpenCV 3.4.1
libtorch 任意版本 64位
Visual Studio 2017 64位编译
相关数据集 Carvana
参考论文 https://paperswithcode.com/paper/u-net-convolutional-networks-for-biomedical
一、UNET是什么?
UNET 是一种 U 型的编码器-解码器网络,对图像中的每个像素进行分类。左侧的收缩路径(编码器)通过卷积和下采样逐级提取分辨率更低、语义更强的特征;右侧的扩张路径(解码器)通过上采样逐级恢复分辨率。两条路径之间的跳跃连接(skip connection)把编码器中同一尺度的高分辨率浅层特征与解码器中的深层特征进行融合,使高维特征与低维特征得到有效结合,显著提升有限样本下的特征提取效果,从而用较少的样本即可获得良好的训练效果和较快的推理速度。
二、使用步骤
1.设置参数
代码如下:
args::ArgumentParser parser("Train the UNet on images and target masks.", "This goes after the options.");
args::HelpFlag help(parser, "help", "Display this help menu", { 'h', "help" });
args::ValueFlag<int> n_class(parser, "n_class", "Pixel classification", { "n_class" }, 1);
args::ValueFlag<int> n_channels(parser, "n_channels", "RGB images", { "n_channels" }, 3);
args::ValueFlag<bool> isbilinear(parser, "bilinear", "if net.bilinear else Transposed conv", { "bilinear" }, 1);
args::ValueFlag<int> epochs(parser, "epochs", "number of epochs to train", { "epochs" }, 5);
args::ValueFlag<int> batch_size(parser, "batch_size", "train patch size", { "batch_size" }, 1);
args::ValueFlag<double> lr(parser, "lr", "learning rate", { "lr" }, 1e-4);
args::ValueFlag<float> scale(parser, "scale", "Downscaling factor of the images", { "scale" }, 0.5);
args::ValueFlag<std::string> str_imgs_train_dir(parser, "imgs_train_dir", "train dataset directory", { "imgs_train_dir" }, "..\\Carvana\\imgs_train");
args::ValueFlag<std::string> str_dir_train_mask(parser, "dir_train_mask", "label dataset directory", { "dir_train_mask" }, "..\\Carvana\\masks_train");
args::ValueFlag<std::string> str_imgs_valid_dir(parser, "imgs_valid_dir", "valid dataset directory", { "imgs_valid_dir" }, "..\\Carvana\\imgs_valid");
args::ValueFlag<std::string> str_dir_valid_mask(parser, "dir_valid_mask", "label valid dataset directory", { "dir_valid_mask" }, "..\\Carvana\\masks_valid");
try
{
parser.ParseCLI(argc, argv);
}
catch (args::Help)
{
std::cout << parser;
return 0;
}
catch (args::ParseError e)
{
std::cerr << e.what() << std::endl;
std::cerr << parser;
return 1;
}
catch (args::ValidationError e)
{
std::cerr << e.what() << std::endl;
std::cerr << parser;
return 1;
}
int Epochs = (int)args::get(epochs);
int batch = (int)args::get(batch_size);
double learning_rate = (double)args::get(lr);
float scale_factor = (float)args::get(scale);
int nclass = (int)args::get(n_class);
int channels = (int)args::get(n_channels);
bool bilinear = (bool)args::get(isbilinear);
string imgs_dir_train = (string)args::get(str_imgs_train_dir);
string dir_mask_train = (string)args::get(str_dir_train_mask);
string imgs_dir_valid = (string)args::get(str_imgs_valid_dir);
string dir_mask_valid = (string)args::get(str_dir_valid_mask);
cout << "Epochs : " << Epochs << endl;
cout << "batch : " << batch << endl;
cout << "learning_rate : " << learning_rate << endl;
cout << "scale : " << scale_factor << endl;
cout << "n_channels : " << channels << endl;
cout << "n_class : " << nclass << endl;
cout << "bilinear : " << bilinear << endl;
cout << "imgs_train : " << imgs_dir_train << endl;
cout << "mask_train : " << dir_mask_train << endl;
cout << "imgs_valid : " << imgs_dir_valid << endl;
cout << "mask_valid : " << dir_mask_valid << endl;
2.读入数据
代码如下:
// Per-channel RGB mean/std (the commonly used ImageNet statistics). They are
// currently unused because the Normalize transform below is commented out;
// they are kept so normalization can be re-enabled without retyping them.
std::vector<double> norm_mean = { 0.485, 0.456, 0.406 };
std::vector<double> norm_std = { 0.229, 0.224, 0.225 };

// Training samples are reshuffled every epoch via RandomSampler.
auto train_dataset = CarvanaDataset(imgs_dir_train, dir_mask_train, scale_factor)/*.map(torch::data::transforms::Normalize<>(norm_mean, norm_std))*/;
auto train_loader = torch::data::make_data_loader<torch::data::samplers::RandomSampler>(
    std::move(train_dataset), batch);

// Validation only measures the model, so shuffling buys nothing. A
// SequentialSampler visits samples in a fixed order, which keeps per-batch
// validation metrics reproducible from run to run.
auto valid_dataset = CarvanaDataset(imgs_dir_valid, dir_mask_valid, scale_factor);
auto valid_loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
    std::move(valid_dataset), batch);
3.UNET网络
代码如下:
/// U-Net segmentation network.
///
/// The encoder is `inc` followed by `down1..down4`. The decoder is
/// `up1..up4`; each stage also receives the encoder output of the matching
/// level (skip connection). `outc` is the output head.
struct UNET : public torch::nn::Module
{
    /// @param n_channels Number of input image channels (3 by default).
    /// @param n_classes  Number of output channels produced by `outc` (1 by default).
    /// @param bilinear   Forwarded to every Up stage. When true, the channel
    ///                   widths of down4 and up1..up3 are halved (factor = 2).
    ///                   NOTE(review): presumably selects bilinear upsampling
    ///                   vs transposed convolution inside Up — confirm in Up.
    UNET(int n_channels = 3, int n_classes = 1, bool bilinear = true)
        : n_channels(n_channels), n_classes(n_classes), bilinear_(bilinear), factor(bilinear ? 2 : 1)
    {
        // Encoder.
        inc = register_module("inc", std::make_shared<DoubleConv>(n_channels, 64));
        down1 = register_module("down1", std::make_shared<Down>(64, 128));
        down2 = register_module("down2", std::make_shared<Down>(128, 256));
        down3 = register_module("down3", std::make_shared<Down>(256, 512));
        down4 = register_module("down4", std::make_shared<Down>(512, 1024 / factor));
        // Decoder: every stage uses the same stored upsampling mode.
        up1 = register_module<Up>("up1", std::make_shared<Up>(1024, 512 / factor, bilinear_));
        up2 = register_module<Up>("up2", std::make_shared<Up>(512, 256 / factor, bilinear_));
        up3 = register_module<Up>("up3", std::make_shared<Up>(256, 128 / factor, bilinear_));
        up4 = register_module<Up>("up4", std::make_shared<Up>(128, 64, bilinear_));
        // Output head.
        outc = register_module<OutConv>("outc", std::make_shared<OutConv>(64, n_classes));
    }

    /// Runs the network on a batch of images.
    ///
    /// Intermediate activations are locals rather than data members. They are
    /// released as soon as this call returns (apart from what autograd itself
    /// keeps for backward), instead of pinning the previous batch's activations
    /// in memory until the next forward pass overwrites them.
    ///
    /// @param x Input batch. NOTE(review): presumably (N, n_channels, H, W) — confirm.
    /// @return The raw output of `outc`; no activation is applied here.
    torch::Tensor forward(torch::Tensor x)
    {
        // Encoder path; each output is kept for the matching skip connection.
        torch::Tensor x1 = inc(x);
        torch::Tensor x2 = down1(x1);
        torch::Tensor x3 = down2(x2);
        torch::Tensor x4 = down3(x3);
        torch::Tensor x5 = down4(x4);
        // Decoder path: fuse with encoder features, deepest level first.
        x = up1(x5, x4);
        x = up2(x, x3);
        x = up3(x, x2);
        x = up4(x, x1);
        return outc(x);
    }

    torch::nn::ModuleHolder<DoubleConv> inc{ nullptr };
    torch::nn::ModuleHolder<Down> down1{ nullptr }, down2{ nullptr }, down3{ nullptr }, down4{ nullptr };
    torch::nn::ModuleHolder<Up> up1{ nullptr }, up2{ nullptr }, up3{ nullptr }, up4{ nullptr };
    torch::nn::ModuleHolder<OutConv> outc{ nullptr };

private:
    int n_channels;   // input channels passed to the constructor
    int n_classes;    // output channels passed to the constructor
    bool bilinear_;   // upsampling mode forwarded to every Up stage
    int factor;       // 2 when bilinear_, else 1; divides the deep-layer widths
};
总结
资源一直上传失败,过后会更新下载地址