Tensorflow API 源码检测模型的基类(一)

Tensorflow API之 Meta Architecture(Detection Model)(一)检测模型的基类

models/research/object_detection/core/model.py

1、 模型框架内容

"""Abstract detection model.
这个文件为检测模型定义了一个通用基类。设计用于任意检测模型的程序应该只依赖于这个类。
我们打算让这个类中的函数遵循张量入/出的设计,因此所有函数都有张量或包含张量的列表/字典作为输入和输出。

Training time:训练时间
inputs (images tensor) -> preprocess(预处理) -> predict(预测) -> loss(损失函数) -> outputs (loss tensor)
训练时间:输入(图像张量)->预处理→预测→损失→张量输出(损失)
Evaluation time:
inputs (images tensor) -> preprocess -> predict -> postprocess(后处理,计算精度)
 -> outputs (boxes tensor, scores tensor, classes tensor, num_detections tensor)
因此检测模型必须实现四种功能(1)预处理,(2)预测,(3)后处理和(4)损失。检测模型不应该对输入
的大小或高宽比做任何假设——他们负责做任何必要的大小调整/重塑。输出类总是范围[0,num类)内的整数。
这些整数到语义标签的任何映射都要在这个类之外处理
Images are resized in the `preprocess` method. All of `preprocess`, `predict`,
and `postprocess` should be reentrant.

The `preprocess` method runs `image_resizer_fn` that returns resized_images(缩放的图片) and
`true_image_shapes(真实图像的形状)`. Since `image_resizer_fn` can pad the images with zeros(这是有用的填充图像,为批处理的固定大小。),
true_image_shapes indicate the slices that contain the image without padding.(真实图像形状表示包含没有填充的图像的切片。)
This is useful for padding images to be a fixed size for batching.
The `postprocess` method uses the true image shapes to clip predictions that lie
outside of images.(后处理方法使用真实的图像形状剪辑图像之外的预测)

2、主要的几个函数

1、init
 def __init__(self, num_classes):
    """Constructor.

    Args:
      num_classes: number of classes.  Note that num_classes *does not* include
      background categories that might be implicitly predicted in various
      implementations.
      初始化_num_classes为指定值;_groundtruth_lists为空字典
    """
    self._num_classes = num_classes
    self._groundtruth_lists = {}
2、 preprocess
def preprocess(self, inputs):
    """Input preprocessing.

    To be overridden by implementations.
    这个函数负责在对输入图像运行检测器之前,对输入值的任何缩放/移位负责。它还负责任何大小调整,
    填充可能是必要的,因为图像被假定到达任意大小。
    虽然这个函数可能是预测方法的一部分(下面),它通常是方便的保持这些分开——例如,
    我们可能想要预处理一个设备,放在一个队列,并让另一个设备(例如,GPU)处理预测。
    Args:
      inputs: a [batch, height_in, width_in, channels] float32 tensor
        representing a batch of images with values between 0 and 255.0.

    Returns:
      preprocessed_inputs: a [batch, height_out, width_out, channels] float32
        tensor representing a batch of images.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.
    """
    pass
3、 predict
 def predict(self, preprocessed_inputs, true_image_shapes):
    """Predict prediction tensors from inputs tensor.
    Outputs of this function can be passed to loss or postprocess functions.

    Args:
      preprocessed_inputs: a [batch, height, width, channels] float32 tensor
        representing a batch of images.预处理的输出
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.

    Returns:
      prediction_dict: a dictionary holding prediction tensors to be
        passed to the Loss or Postprocess functions.(预测字典)
    """
    pass

4、loss
 def loss(self, prediction_dict, true_image_shapes):
    """Compute scalar loss tensors with respect to provided groundtruth.

    Calling this function requires that groundtruth tensors have been
    provided via the provide_groundtruth function.(需要调用provide_groundtruth函数提供 groundtruth tensors张量)

    Args:
      prediction_dict: a dictionary holding predicted tensors  #预测张量字典
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.

    Returns:
      a dictionary mapping strings (loss names) to scalar tensors representing
        loss values.
    """
    pass

 def provide_groundtruth(self,
                          groundtruth_boxes_list,
                          groundtruth_classes_list,
                          groundtruth_masks_list=None,
                          groundtruth_keypoints_list=None,
                          groundtruth_weights_list=None,
                          groundtruth_is_crowd_list=None):
    """Provide groundtruth tensors.
    Args:
      groundtruth_boxes_list: a list of 2-D tf.float32 tensors of shape
        [num_boxes, 4] containing coordinates of the groundtruth boxes.
    1、Groundtruth boxes are provided in [y_min, x_min, y_max, x_max]#真实的检测框
          
          format and assumed to be normalized and clipped(被标准化和裁剪)
          relative to the image window with y_min <= y_max and x_min <= x_max.
    2、groundtruth_classes_list: a list of 2-D tf.float32 one-hot (or k-hot)
        tensors of shape [num_boxes, num_classes] containing the class targets
        with the 0th index assumed to map to the first non-background class.
      groundtruth_masks_list: a list of 3-D tf.float32 tensors of
        shape [num_boxes, height_in, width_in] containing instance
        masks with values in {0, 1}.  If None, no masks are provided.
        Mask resolution `height_in`x`width_in` must agree with the resolution
        of the input image tensor provided to the `preprocess` function.
      groundtruth_keypoints_list: a list of 3-D tf.float32 tensors of
        shape [num_boxes, num_keypoints, 2] containing keypoints.
        Keypoints are assumed to be provided in normalized coordinates and
        missing keypoints should be encoded as NaN.
      groundtruth_weights_list: A list of 1-D tf.float32 tensors of shape
        [num_boxes] containing weights for groundtruth boxes.
      groundtruth_is_crowd_list: A list of 1-D tf.bool tensors of shape
        [num_boxes] containing is_crowd annotations
    """
    self._groundtruth_lists[fields.BoxListFields.boxes] = groundtruth_boxes_list
    self._groundtruth_lists[
        fields.BoxListFields.classes] = groundtruth_classes_list
    if groundtruth_weights_list:
      self._groundtruth_lists[fields.BoxListFields.
                              weights] = groundtruth_weights_list
    if groundtruth_masks_list:
      self._groundtruth_lists[
          fields.BoxListFields.masks] = groundtruth_masks_list
    if groundtruth_keypoints_list:
      self._groundtruth_lists[
          fields.BoxListFields.keypoints] = groundtruth_keypoints_list
    if groundtruth_is_crowd_list:
      self._groundtruth_lists[
          fields.BoxListFields.is_crowd] = groundtruth_is_crowd_list

5、postprocess(后处理)
def postprocess(self, prediction_dict, true_image_shapes, **params):
    """Convert predicted output tensors to final detections.

    Outputs adhere to the following conventions:
    * Classes are integers in [0, num_classes); background classes are removed
      and the first non-background class is mapped to 0. If the model produces
      class-agnostic detections, then no output is produced for classes.
    * Boxes are to be interpreted as being in [y_min, x_min, y_max, x_max]
      format and normalized relative to the image window.
    * ' num_detections '用于将detections填充为a的设置固定数量的盒子
    * We do not specifically assume any kind of probabilistic interpretation
      of the scores --- the only important thing is their relative ordering.
      Thus implementations of the postprocess function are free to output
      logits, probabilities, calibrated probabilities, or anything else.
    因此,后处理函数的实现可以自由地输出日志、概率、校准后的概率或其他任何东西。
    Args:
      prediction_dict: a dictionary holding prediction tensors.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.
      **params: Additional keyword arguments for specific implementations of
        DetectionModel.

    Returns:
        detections: a dictionary containing the following fields
        detection_boxes: [batch, max_detections, 4]
        detection_scores: [batch, max_detections]
        detection_classes: [batch, max_detections]
          (If a model is producing class-agnostic detections, this field may be
          missing)
        instance_masks: [batch, max_detections, image_height, image_width]
          (optional)
        keypoints: [batch, max_detections, num_keypoints, 2] (optional)
        num_detections: [batch]
    """
    pass
6、restore_map(训练模型时模型恢复)
 def restore_map(self, fine_tune_checkpoint_type='detection'):
    """Returns a map of variables to load from a foreign checkpoint.

    Returns a map of variable names to load from a checkpoint to variables in
    the model graph. This enables the model to initialize based on weights from
    another task. For example, the feature extractor variables from a
    classification model can be used to bootstrap training of an object
    detector. When loading from an object detection model, the checkpoint model
    should have the same parameters as this detection model with exception of
    the num_classes parameter.返回要从检查点加载的变量名映射到模型图中的变量。这
    使得模型可以根据另一个任务的权重进行初始化。例如,分类模型中的特征提取器变量可以用于引导对象检测器的训练。从对象检测模型加载时,检查点模型应该具有与该检测模型相同的参数,但num_classes参数除外。

    Args:
      fine_tune_checkpoint_type: whether to restore from a full detection
        checkpoint (with compatible variable names) or to restore from a
        classification checkpoint for initialization prior to training.
        Valid values: `detection`, `classification`. Default 'detection'.

    Returns:
      A dict mapping variable names (to load from a checkpoint) to variables in
      the model graph.
    """
    pass
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值