SimpleHumanPose代码及原理分析(二)-- data与label前处理

4 篇文章 0 订阅
3 篇文章 0 订阅

SimpleHumanPose代码及原理分析(二)-- data与label前处理

在《SimpleHumanPose代码及原理分析(一)-- coco keypoints数据集》中介绍了COCO数据集,本篇介绍个人对data与label前处理的理解。
heatmap_data_layer.cpp代码
这个 cpp 文件是数据前处理的 C++ 实现,下面按函数逐一解析。

一、DataLayerSetUp函数

namespace caffe {
    /// Destructor: stops the layer's internal (prefetch) thread before the layer
    /// object is torn down. Spelled `~HeatmapDataLayer()` rather than
    /// `~HeatmapDataLayer<Dtype>()`: a template-id as destructor name is not
    /// permitted since C++20 (CWG 2237).
    template<typename Dtype>
    HeatmapDataLayer<Dtype>::~HeatmapDataLayer() {
        this->StopInternalThread();
    }

    /**
     * One-time layer setup.
     *
     * Reads the list file `source`; every non-blank line has the form
     *   <image_path relative to root_folder> x_1 y_1 x_2 y_2 ... x_K y_K   (K = coordinate_num)
     * Optionally shuffles the list and applies a random initial skip, then reads one
     * image (resized to new_height x new_width) to learn its channel count and
     * shapes the outputs:
     *   top[0] / prefetch data : batch_size x channels x crop_size x crop_size
     *                            (new_height x new_width when crop_size == 0)
     *   top[1] / prefetch label: batch_size x coordinate_num x label_size x label_size
     */
    template<typename Dtype>
    void HeatmapDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype> *> &bottom,
                                                 const vector<Blob<Dtype> *> &top) {
        const HeatmapDataParameter &param = this->layer_param_.heatmap_data_param();
        const int new_height = param.new_height();      // every image is resized to this height
        const int new_width = param.new_width();        // ... and to this width
        const int crop_size = param.crop_size();        // side of the square crop taken after resizing (0 = no crop)
        const int points_num = param.coordinate_num();  // number of keypoints per image
        const bool is_color = param.is_color();
        const string root_folder = param.root_folder(); // prefix prepended to every image path
        const int label_size = param.label_size();      // side of each output heatmap

        CHECK(new_height > 0 && new_width > 0) << "Both new_width and new_height must be greater than 0.";
        CHECK(new_height >= crop_size && new_width >= crop_size)
        << "Both new_width and new_height must greater than crop_size.";

        // Read the file with filenames and labels.
        const string &source = param.source();
        LOG(INFO) << "Opening file " << source;
        std::ifstream infile(source.c_str());
        // Without this, a missing/unreadable file is reported later as "File is empty".
        CHECK(infile.is_open()) << "Could not open source file " << source;
        const char *const kWhitespace = " \t\r\n";
        string line;
        vector<string> line_split;
        while (std::getline(infile, line)) {
            // Trim surrounding whitespace (including a CRLF '\r') so leading/trailing
            // blanks do not yield empty tokens, and skip blank lines entirely.
            const string::size_type first = line.find_first_not_of(kWhitespace);
            if (first == string::npos) {
                continue;
            }
            const string::size_type last = line.find_last_not_of(kWhitespace);
            line = line.substr(first, last - first + 1);
            // line_split = <image_path, x_1, y_1, ..., x_K, y_K>
            boost::split(line_split, line, boost::is_any_of(" "), boost::token_compress_on);
            CHECK_EQ((int) line_split.size() - 1, points_num*2) << "points num is not same with source.";
            lines_.push_back(line_split);
        }
        CHECK(!lines_.empty()) << "File is empty";

        if (param.shuffle()) {
            LOG(INFO) << "Shuffling data";
            const unsigned int prefetch_rng_seed = caffe_rng_rand();
            prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));
            ShuffleImages();
        } else if (this->phase_ == TRAIN && Caffe::solver_rank() > 0 && param.rand_skip() == 0) {
            LOG(WARNING) << "Shuffling or skipping recommended for multi-GPU";
        }
        LOG(INFO) << "A total of " << lines_.size() << " images.";

        // Optionally start reading at a random offset in [0, rand_skip); presumably to
        // decorrelate parallel solvers (see the multi-GPU warning above).
        lines_id_ = 0;
        if (param.rand_skip()) {
            unsigned int skip = caffe_rng_rand() % param.rand_skip();
            LOG(INFO) << "Skipping first " << skip << " data points.";
            CHECK_GT(lines_.size(), skip) << "Not enough points to skip";
            lines_id_ = skip;
        }

        // Read one image only to learn its channel count for the data shape.
        cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_][0],
                                          new_height, new_width, is_color);
        CHECK(cv_img.data) << "Could not load " << lines_[lines_id_][0];
        vector<int> top_shape(4);
        top_shape[0] = 1;
        top_shape[1] = cv_img.channels();
        top_shape[2] = crop_size ? crop_size : new_height;
        top_shape[3] = crop_size ? crop_size : new_width;
        trans_data_tmp_.Reshape(top_shape);

        // Reshape prefetch_data and top[0] according to the batch_size.
        const int batch_size = param.batch_size();
        CHECK_GT(batch_size, 0) << "Positive batch size required";
        top_shape[0] = batch_size;
        top[0]->Reshape(top_shape);
        for (int i = 0; i < this->prefetch_.size(); ++i) {
            this->prefetch_[i]->data_.Reshape(top_shape);
        }
        LOG(INFO) << "output data size: " << top[0]->num() << ","
                  << top[0]->channels() << "," << top[0]->height() << ","
                  << top[0]->width();

        // Labels: one label_size x label_size heatmap per keypoint.
        vector<int> label_shape(4);
        label_shape[0] = 1;
        label_shape[1] = points_num;
        label_shape[2] = label_size;
        label_shape[3] = label_size;
        trans_label_tmp_.Reshape(label_shape);
        label_shape[0] = batch_size;
        top[1]->Reshape(label_shape);
        for (int i = 0; i < this->prefetch_.size(); ++i) {
            this->prefetch_[i]->label_.Reshape(label_shape);
        }
    }

二、load_batch函数

// This function is called on prefetch thread
    /**
     * Fills `batch` with the next batch_size samples.
     *
     * For each sample: read the image, resize it to new_height x new_width, and let
     * transform_data_label() write the augmented pixels and heatmaps straight into
     * the batch's memory (trans_data_tmp_ / trans_label_tmp_ are re-pointed at the
     * sample's slot via set_cpu_data). lines_id_ advances once per sample and wraps
     * to 0 at the end of the list, reshuffling when `shuffle` is enabled.
     */
    template<typename Dtype>
    void HeatmapDataLayer<Dtype>::load_batch(Batch<Dtype> *batch) {
        CPUTimer batch_timer;
        batch_timer.Start();
        double read_time = 0;
        double trans_time = 0;
        CPUTimer timer;
        CHECK(batch->data_.count());
        CHECK(trans_data_tmp_.count());
        // Bind by reference: no need to copy the protobuf message for every batch.
        const HeatmapDataParameter &heatmap_data_param = this->layer_param_.heatmap_data_param();
        const int batch_size = heatmap_data_param.batch_size();
        const int new_height = heatmap_data_param.new_height();
        const int new_width = heatmap_data_param.new_width();
        const int crop_size = heatmap_data_param.crop_size();
        const bool is_color = heatmap_data_param.is_color();
        const string root_folder = heatmap_data_param.root_folder();
        const int cv_read_flag = is_color ? CV_LOAD_IMAGE_COLOR : CV_LOAD_IMAGE_GRAYSCALE;

        // Reshape according to the first image of the batch (only its channel count
        // matters; the spatial size is fixed by crop_size / new_height x new_width).
        cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_][0], new_height, new_width, is_color);
        CHECK(cv_img.data) << "Could not load " << lines_[lines_id_][0];
        vector<int> top_shape(4);
        top_shape[0] = 1;
        top_shape[1] = cv_img.channels();
        top_shape[2] = crop_size ? crop_size : new_height;
        top_shape[3] = crop_size ? crop_size : new_width;
        trans_data_tmp_.Reshape(top_shape);
        top_shape[0] = batch_size;
        batch->data_.Reshape(top_shape);

        Dtype *prefetch_data = batch->data_.mutable_cpu_data();
        Dtype *prefetch_label = batch->label_.mutable_cpu_data();

        const int lines_size = lines_.size();
        for (int item_id = 0; item_id < batch_size; ++item_id) {
            timer.Start();
            CHECK_GT(lines_size, lines_id_);
            const string filename = root_folder + lines_[lines_id_][0];
            const cv::Mat cv_img_origin = cv::imread(filename, cv_read_flag);
            // Fail fast with the file name: continuing would call cv::resize on an empty
            // Mat and divide by its zero width/height below.
            CHECK(cv_img_origin.data) << "Could not open or find file " << filename;
            cv::Mat img;
            if (new_height > 0 && new_width > 0) {
                cv::resize(cv_img_origin, img, cv::Size(new_width, new_height));
            } else {
                img = cv_img_origin;
            }
            // Resize ratios; the keypoint labels are given in original-image pixels.
            const float img_width_scale = (float) img.cols / cv_img_origin.cols;
            const float img_height_scale = (float) img.rows / cv_img_origin.rows;
            CHECK(img.data) << "Could not load " << lines_[lines_id_][0];
            read_time += timer.MicroSeconds();

            // Apply transformations (crop, rotation, mirror...) directly into the batch.
            timer.Start();
            const int data_offset = batch->data_.offset(item_id);
            const int label_offset = batch->label_.offset(item_id);
            trans_data_tmp_.set_cpu_data(prefetch_data + data_offset);
            trans_label_tmp_.set_cpu_data(prefetch_label + label_offset);
            transform_data_label(&trans_data_tmp_, &trans_label_tmp_, img, img_width_scale, img_height_scale);
            trans_time += timer.MicroSeconds();

            // Advance only after transform_data_label, which reads lines_[lines_id_].
            lines_id_++;
            if (lines_id_ >= lines_size) {
                // We have reached the end. Restart from the first.
                DLOG(INFO) << "Restarting data prefetching from start.";
                lines_id_ = 0;
                if (heatmap_data_param.shuffle()) {
                    ShuffleImages();
                }
            }
        }
        batch_timer.Stop();
        DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
        DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";
        DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
    }

三、transform_data_label函数(核心)

/**
 * Augments one image and writes both the network input and the matching
 * keypoint heatmaps. Labels go through exactly the same geometric steps as pixels:
 *   1. crop: random offset in TRAIN, centered otherwise (skipped when crop_size == 0);
 *   2. rotation by Uniform(-max_angle, max_angle) degrees, applied when `rotation`
 *      is set and Rand(2) is non-zero;
 *   3. horizontal mirror when `mirror` is set and Rand(2) is non-zero; afterwards the
 *      keypoint indices listed in `mirror_pairs` ("a,b" or "a b") are swapped;
 *   4. mean subtraction (mean file or mean_value) and multiplication by `scale`.
 *
 * @param data_blob        destination for the transformed image, 1 x C x H x W.
 * @param label_blob       destination for the heatmaps (filled by generate_hm).
 * @param img              image already resized by the caller; must be CV_8U.
 * @param img_width_scale  img.cols / original width; maps label x onto img.
 * @param img_height_scale img.rows / original height; maps label y onto img.
 *
 * Keypoints are read from lines_[lines_id_], so the caller must not advance
 * lines_id_ before calling. A coordinate <= 0 marks a missing keypoint; missing
 * keypoints and keypoints moved outside the H x W window are passed to
 * generate_hm as (-1, -1).
 */
template<typename Dtype>
void HeatmapDataLayer<Dtype>::transform_data_label(Blob<Dtype> *data_blob, Blob<Dtype> *label_blob, const cv::Mat &img,
                                                   float img_width_scale, float img_height_scale) {
    const HeatmapDataParameter &param = this->layer_param_.heatmap_data_param();
    const bool has_mean_file = param.has_mean_file();
    const string mean_file = param.mean_file();
    const bool do_rotation = param.rotation() && Rand(2);
    const float max_angle = param.max_angle();            // rotation range is [-max_angle, max_angle] degrees
    const int mirror_pairs_size = param.mirror_pairs_size();
    const int points_num = param.coordinate_num();
    const int crop_size = param.crop_size();

    const int img_channels = img.channels();
    const int img_height = img.rows;
    const int img_width = img.cols;

    // Check dimensions of the destination against the (resized) source image.
    const int channels = data_blob->channels();
    const int height = data_blob->height();
    const int width = data_blob->width();
    const int num = data_blob->num();
    CHECK_EQ(channels, img_channels);
    CHECK_LE(height, img_height);
    CHECK_LE(width, img_width);
    CHECK_GE(num, 1);
    CHECK(img.depth() == CV_8U) << "Image data type must be unsigned byte";

    // NOTE(review): the mean file is re-read from disk for every image; loading it
    // once (e.g. in DataLayerSetUp) would avoid this I/O.
    Blob<Dtype> data_mean;
    if (has_mean_file) {
        BlobProto blob_proto;
        ReadProtoFromBinaryFileOrDie(mean_file, &blob_proto);
        data_mean.FromProto(blob_proto);
    }

    const Dtype scale = param.scale();
    const bool do_mirror = param.mirror() && Rand(2);

    vector<Dtype> mean_values;
    for (int c = 0; c < param.mean_value_size(); ++c) {
        mean_values.push_back(param.mean_value(c));
    }
    const bool has_mean_values = !mean_values.empty();

    CHECK_GT(img_channels, 0);
    CHECK_GE(img_height, crop_size);
    CHECK_GE(img_width, crop_size);

    Dtype *mean = NULL;
    if (has_mean_file) {
        // The mean image must match the resized (uncropped) image.
        CHECK_EQ(img_channels, data_mean.channels());
        CHECK_EQ(img_height, data_mean.height());
        CHECK_EQ(img_width, data_mean.width());
        mean = data_mean.mutable_cpu_data();
    }
    if (has_mean_values) {
        CHECK(mean_values.size() == 1 || mean_values.size() == img_channels)
        << "Specify either 1 mean_value or as many as channels: " << img_channels;
        if (img_channels > 1 && mean_values.size() == 1) {
            // Broadcast a single mean_value to every channel.
            mean_values.assign(img_channels, Dtype(mean_values[0]));
        }
    }

    // 1. Crop.
    int h_off = 0;
    int w_off = 0;
    cv::Mat cv_cropped_img = img;
    if (crop_size) {
        CHECK_EQ(crop_size, height);
        CHECK_EQ(crop_size, width);
        if (this->phase_ == TRAIN) {
            // Random crop only during training.
            h_off = Rand(img_height - crop_size + 1);
            w_off = Rand(img_width - crop_size + 1);
        } else {
            // Deterministic center crop otherwise.
            h_off = (img_height - crop_size) / 2;
            w_off = (img_width - crop_size) / 2;
        }
        cv_cropped_img = img(cv::Rect(w_off, h_off, crop_size, crop_size));
    } else {
        CHECK_EQ(img_height, height);
        CHECK_EQ(img_width, width);
    }
    CHECK(cv_cropped_img.data);

    // 2. Rotation. Rotate into a separate Mat: the crop is a view into the caller's
    // const image, which must not be overwritten.
    cv::Mat cv_rotated_img;
    float angle = 0;
    if (do_rotation) {
        angle = Uniform(-max_angle, max_angle);
        RotateImage(cv_cropped_img, cv_rotated_img, angle);
    } else {
        cv_rotated_img = cv_cropped_img;
    }
    CHECK(cv_rotated_img.data);

    // 3./4. Copy interleaved uint8 pixels into the planar C x H x W blob,
    // mirroring horizontally if requested, subtracting the mean and scaling.
    Dtype *transform_data = data_blob->mutable_cpu_data();
    for (int h = 0; h < height; ++h) {
        const uchar *ptr = cv_rotated_img.ptr<uchar>(h);
        int img_index = 0;
        for (int w = 0; w < width; ++w) {
            const int dst_w = do_mirror ? (width - 1 - w) : w;
            for (int c = 0; c < img_channels; ++c) {
                const int top_index = (c * height + h) * width + dst_w;
                const Dtype pixel = static_cast<Dtype>(ptr[img_index++]);
                if (has_mean_file) {
                    const int mean_index = (c * img_height + h_off + h) * img_width + w_off + w;
                    transform_data[top_index] = (pixel - mean[mean_index]) * scale;
                } else if (has_mean_values) {
                    transform_data[top_index] = (pixel - mean_values[c]) * scale;
                } else {
                    transform_data[top_index] = pixel * scale;
                }
            }
        }
    }

    // Labels, step 1: original-image coordinates -> resized image -> crop window.
    const vector<string> &record = lines_[lines_id_];
    const int coords_size = (int) record.size() - 1;
    CHECK_EQ(coords_size, points_num*2) << "keypoints must has has points_nums points.";
    vector<cv::Point> points(points_num);
    for (int i = 0; i < points_num; ++i) {
        const int x_ori = std::atoi(record[2*i + 1].c_str());
        const int y_ori = std::atoi(record[2*i + 2].c_str());
        if (x_ori > 0 && y_ori > 0) {
            points[i].x = (int)roundf(x_ori * img_width_scale) - w_off;
            points[i].y = (int)roundf(y_ori * img_height_scale) - h_off;
        } else {
            // Non-positive coordinate => missing keypoint. Invalidate BOTH coordinates;
            // a half-valid point such as (x, -1) could otherwise be rotated into view.
            points[i].x = -1;
            points[i].y = -1;
        }
    }

    // Labels, step 2: rotate with the same transform as the pixels.
    // cv::getRotationMatrix2D treats a positive angle as counter-clockwise, which in
    // image coordinates (y pointing down) corresponds to the negated angle here.
    if (do_rotation) {
        const float arc_angle = angle * M_PI / -180;
        const float cos_a = std::cos(arc_angle);
        const float sin_a = std::sin(arc_angle);
        const float center_x = width / 2.0f;
        const float center_y = height / 2.0f;
        for (int i = 0; i < points_num; ++i) {
            if (points[i].x == -1 && points[i].y == -1) {
                continue;  // missing keypoint: keep the marker as is
            }
            const float r_x = points[i].x - center_x;
            const float r_y = points[i].y - center_y;
            points[i].x = (int)roundf(center_x + r_x * cos_a - r_y * sin_a);
            points[i].y = (int)roundf(center_y + r_y * cos_a + r_x * sin_a);
        }
    }

    // Labels, step 3: keypoints pushed outside the window by crop/rotation become missing.
    for (int i = 0; i < points_num; ++i) {
        const bool inside = points[i].x >= 0 && points[i].x < width &&
                            points[i].y >= 0 && points[i].y < height;
        if (!inside) {
            points[i].x = -1;
            points[i].y = -1;
        }
    }

    // Labels, step 4: mirror, then swap left/right keypoint indices.
    if (do_mirror) {
        for (int i = 0; i < points_num; ++i) {
            if (points[i].x != -1 && points[i].y != -1) {
                points[i].x = width - 1 - points[i].x;
            }
        }
        vector<string> kpt;
        for (int i = 0; i < mirror_pairs_size; ++i) {
            const string &pair = param.mirror_pairs(i);
            kpt.clear();
            boost::split(kpt, pair, boost::is_any_of(" ,"), boost::token_compress_on);
            CHECK_EQ((int) kpt.size(), 2) << "mirror_pair must has two element.";
            const int one_idx = std::atoi(kpt[0].c_str());
            const int ano_idx = std::atoi(kpt[1].c_str());
            // Guard against an out-of-range index in the config (would be UB below).
            CHECK(one_idx >= 0 && one_idx < points_num && ano_idx >= 0 && ano_idx < points_num)
                << "mirror_pair index out of range: " << pair;
            const cv::Point tmp = points[one_idx];
            points[one_idx] = points[ano_idx];
            points[ano_idx] = tmp;
        }
    }

    // Labels, step 5: scale from the H x W input frame to the label_size heatmap frame.
    const int hm_size = param.label_size();
    const float hm_width_scale = (float) hm_size / width;
    const float hm_height_scale = (float) hm_size / height;
    vector<cv::Point> hm_pts(points);
    for (int i = 0; i < points_num; ++i) {
        if (hm_pts[i].x < 0 || hm_pts[i].y < 0) {
            continue;  // missing keypoint stays (-1, -1)
        }
        hm_pts[i].x = (int)roundf(hm_width_scale * hm_pts[i].x);
        hm_pts[i].y = (int)roundf(hm_height_scale * hm_pts[i].y);
    }
    generate_hm(hm_pts, label_blob);
}

四、generate_hm函数

/**
 * Renders one Gaussian heatmap per keypoint into label_blob.
 *
 * @param hm_pts     keypoints in heatmap coordinates, one per label channel;
 *                   a negative x or y marks a missing keypoint.
 * @param label_blob output blob of shape 1 x hm_pts.size() x H x W.
 *
 * For a visible keypoint (x, y), cell (i, j) gets
 *   exp(-((i - y)^2 + (j - x)^2) / (2 * sigma^2)),  sigma = 2,
 * and cells where that value is <= 1e-4 stay 0. A missing keypoint's whole
 * channel is set to -1 (presumably an "ignore" marker for the loss - TODO confirm).
 */
template<typename Dtype>
void HeatmapDataLayer<Dtype>::generate_hm(const vector<cv::Point> &hm_pts, Blob<Dtype> *label_blob) {
    CHECK_EQ(static_cast<int>(hm_pts.size()), label_blob->channels())
        << "heatmap points are same with channels of label_blob";
    CHECK(label_blob->shape().size() == 4) << "label Blob must has 4 dimension.";
    CHECK(label_blob->num() == 1) << "label Blob .num() must be 1.";
    const int label_height = label_blob->height();
    const int label_width = label_blob->width();
    const int label_num_channels = label_blob->channels();
    const int label_channel_size = label_height * label_width;

    const double sigma = 2;
    const double inv_sigma_sq = 1.0 / (sigma * sigma);
    const double kMinValue = 0.0001;  // Gaussian tails at or below this stay exactly 0

    Dtype *label_ptr = label_blob->mutable_cpu_data();
    caffe_set(label_blob->count(), Dtype(0), label_ptr);
    for (int idx_ch = 0; idx_ch < label_num_channels; ++idx_ch) {
        Dtype *channel = label_ptr + idx_ch * label_channel_size;
        const cv::Point &pt = hm_pts[idx_ch];
        if (pt.x < 0 || pt.y < 0) {
            caffe_set(label_channel_size, Dtype(-1), channel);
            continue;
        }
        for (int i = 0; i < label_height; ++i) {
            const double dy = i - pt.y;
            for (int j = 0; j < label_width; ++j) {
                const double dx = j - pt.x;
                // Plain multiplication instead of pow(x, 2) in this per-pixel loop.
                const float gaussian = (float) std::exp(-0.5 * (dy * dy + dx * dx) * inv_sigma_sq);
                if (gaussian > kMinValue) {
                    channel[i * label_width + j] = gaussian;
                }
            }
        }
    }
}

五、Rand函数(生成 [0, n) 范围内的随机整数)

    /**
     * Returns a pseudo-random integer in [0, n) drawn from caffe_rng_rand().
     * n must be positive (CHECKed). The modulo reduction is slightly biased
     * unless n divides the generator's range.
     */
    template<typename Dtype>
    int HeatmapDataLayer<Dtype>::Rand(int n) {
        CHECK_GT(n, 0);
        // Draw from the existing generator instead of constructing and seeding a
        // brand-new Caffe::RNG on every call (this runs several times per image).
        return static_cast<int>(caffe_rng_rand() % static_cast<unsigned int>(n));
    }

六、RotateImage函数(图片旋转函数)

    /**
     * Rotates src by `angle` degrees (OpenCV convention: positive = counter-clockwise)
     * about its center into dst, keeping src's size; areas not covered by the rotated
     * image get cv::warpAffine's default constant border.
     * @return the 2x3 affine matrix that was applied.
     */
    template<typename Dtype>
    cv::Mat HeatmapDataLayer<Dtype>::RotateImage(cv::Mat &src, cv::Mat &dst, float angle) {
        // Sub-pixel center, matching the label rotation center (width / 2.0f,
        // height / 2.0f) in transform_data_label; an integer cv::Point would
        // truncate it by 0.5 px for odd sizes.
        const cv::Point2f center(src.cols / 2.0f, src.rows / 2.0f);
        const double scale = 1.0;
        cv::Mat rot_mat = cv::getRotationMatrix2D(center, angle, scale);
        cv::warpAffine(src, dst, rot_mat, src.size());
        return rot_mat;
    }

七、Uniform函数(均匀分布随机数函数)

例如:在本文件中,传入的是最大旋转角的相反数和最大旋转角本身(如 -30 和 30),该函数在 [-30, 30] 区间内均匀地随机选取一个旋转角度。

    /**
     * Returns min + u * (max - min) with u uniformly distributed in [0, 1];
     * e.g. Uniform(-30, 30) yields a rotation angle in [-30, 30].
     */
    template<typename Dtype>
    float HeatmapDataLayer<Dtype>::Uniform(const float min, const float max) {
        // Use Caffe's RNG instead of C rand(): rand() is not seeded by this layer, is
        // independent of the generator behind Rand(), and the C standard does not
        // guarantee it is thread-safe (this runs on the prefetch thread).
        float unit = 0.f;
        caffe_rng_uniform(1, 0.f, 1.f, &unit);
        return min + unit * (max - min);
    }
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

进我的收藏吃灰吧~~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值