网络训练生成飞行轨迹
介绍
本文是博客的子文。介绍如何对网络进行训练。主要代码文件如下:
整体代码框图如下,本文介绍框图中的B模块:
网络结构
网络结构定义代码在 nets.py中,模型名称为 PlaNet.
参考整体框图中的B模块。代码中_internal_call函数,实现了B模块的流程。
网络默认不输入RGB图像,仅输入SGM深度图像和飞机的位姿,以及给定轨迹方向。
因此,_internal_call函数分别对输入数据进行处理:
- 首先处理飞机的位姿和期望方向,对应B框图中的States Backbone
imu_embeddings = self._imu_branch(imu_obs)。 - 再处理输入的图像数据,对应B框图中的Depth Backbone
img_embeddings = self._preprocess_frames(inputs) - 整合飞机位姿和图像数据,
total_embeddings = tf.concat((img_embeddings, imu_embeddings), axis=-1) - 网络计算
output = self._plan_branch(total_embeddings)
def _internal_call(self, inputs):
if self.config.use_position:
imu_obs = inputs['imu']
else:
# always pass z
imu_obs = inputs['imu'][:, :, 3:]
if (n