Robotics Diffusion Transformer,简称 RDT,是一个基于扩散模型的多模态具身智能模型。本篇在上一篇的基础之上梳理一下模型训练和推理过程中,数据集的处理流程。接上一篇:
RDT模型训练过程中输入数据中,有2个重要的数据就是state和action,对应上一篇架构图中的
Z
t
Z_{t}
Zt和
a
t
a_{t}
at,Zt在论文中英文名是proprioception,中文翻译过来是“本体感知”,就是一个128维的向量。每个维度的是什么含义参考上一篇文章。作者的目标是将不同型号的具身智能机器人动作(例如有双臂的,单臂的,A品牌的,B品牌的…)统一到一个128维的动作空间中。这样可以将历史上公开的数据集都统一的使用起来。
举例来讲,在Open X-Embodiment数据集中,CLVR Jaco Play子数据集采集了机械臂的End effector的x,y,z,各坐标轴的旋转角,还有各坐标轴的角速度。而Droid数据集除了End effector的x,y,z,各坐标轴的旋转角外,还采集了各joint(关节)的角度。作者将这些不同的字段都统一到了128维空间中,128个字段足以表达各种数据集的不同。
作者的训练pipeline分两种,一个是pretrain的流程和一个finetune的流程,pretrain就是从零开始训练,用大量的数据。finetune就是微调,用少量的数据。
pretrain
作者pretrain使用了46个数据集,总计100万+条轨迹,21TB的数据。因为数据量较大,在训练过程中动态的加载处理可能会影响训练速度,所以作者在工程上使用的producer&consumer的生产者&消费者的模式。
下面讲解过程中会有episode和step两个概念。例如将瓶子放在垃圾桶中这个指令,采集了5组数据,每组数据都包含完整的机械臂运动至瓶子旁边,用夹子夹住,然后提起放在垃圾桶中这一套完整的动作叫一个episode,一般一个相同的指令需要采集多个episode。每个episode一般包含很多采样,例如可以采样180条数据,每条数据都包含那一采样时刻的state, action, image, text指令等全部信息。
- producer&consumer
作者在硬盘上设计了一个buffer(缓冲区),buffer包括512个chunk,每个chunk可容纳512条数据(对应一个step)。每个chunk中有一个512字节的文件(dirty_bit),每个字节代表这个chunk的512条数据是否已经被consumer消费。若值为1,就是代表被消费了,producer可以再生产一条数据覆盖这个“脏”数据。producer使用多线程加速生产数据,所以可以在文件夹中看到大量的锁,以保证不同的producer不会写同一个数据,或者是消费者消费还没有生产完成的数据。同理 ,消费者只会消费dirty_bit=0的那条数据。整体上缓冲区可容纳512*512条数据,生产者会将不同数据集打乱后,不断的放在这个buffer中。
(上面的数字是参数,都可配置)
buffer文件结构如下:
.
├── chunk_0
│ ├── dirty_bit
│ ├── json_content_0.json
│ ├── sample_0.npz
│ ...
├── chunk_1
│ ├── dirty_bit
│ ├── json_content_0.json
│ ├── sample_0.npz
│ ...
...
每条数据有2个文件,例如针对第一条数据,第1个文件名是json_content_0.json,内容示例如下:
{
"dataset_name": "cmu_stretch",
"#steps": 180,#此episode共有180个step
"instruction": "Lift up a lid from the pot."
}
第2个文件名是sample_0.npz,内容示例如下(仅关注key即可,部分注释如下):
{
'step_id': episode['step_id'][i], #当前episode的哪一个step
'state_chunk': past_states[i], #64条历史state,训练的时候只使用当前时刻那一条
'state_chunk_time_mask': past_states_time_mask[i],#并不是每一步都有历史state,例如step_id<64的step的历史state就不够64条,所以用一个mask来标记是否有效
'action_chunk': future_states[i],#64条将来state,action_chunk其实是使用了将来的state作为action
'action_chunk_time_mask': future_states_time_mask[i],#同state_chunk_time_mask
'state_vec_mask': masks[i],#state的128维的哪个字段是有效的
'past_frames_0': episode['past_frames_0'][i],#每条数据默认有4*2张图片输入,frame_0就是第1*2张,frame_0包含2张图片,一张是当前,一张是上一时刻。
'past_frames_0_time_mask': episode['past_frames_0_time_mask'][i],#并不是每条数据都有上一时刻的照片,用mask标记
'past_frames_1': episode['past_frames_1'][i],#同frame_0,若只有一个相机,那么frame_1就是空的
'past_frames_1_time_mask': episode['past_frames_1_time_mask'][i],#同frame_0
'past_frames_2': episode['past_frames_2'][i],#同frame_0
'past_frames_2_time_mask': episode['past_frames_2_time_mask'][i],#同frame_0
'past_frames_3': episode['past_frames_3'][i],#同frame_0
'past_frames_3_time_mask': episode['past_frames_3_time_mask'][i],#同frame_0
'state_std': state_std[i],#state的统计信息
'state_mean': state_mean[i],#state的统计信息
'state_norm': state_norm[i],#state的统计信息
}
- 不同数据集如何在一个框架中处理
每个数据集都需要提供一个脚本来将数据格式统一,脚本需要提供process_step函数,示例如下:
def process_step(step: dict) -> dict:
# Concatenate the action
arm_action = eef_delta_pos
action['arm_concat'] = arm_action
# base_action = tf.constant([0, 0, 0, 0], dtype=tf.float32)
# action['base_concat'] = None
# Write the action format
action['format'] = tf.constant(
"eef_delta_pos_x,eef_delta_pos_y,eef_delta_pos_z")
# Convert raw state to our state
state = step['observation']
joint_pos = state['joint_pos']
eef_pos = state['end_effector_cartesian_pos'][:3]
eef_quat = state['end_effector_cartesian_pos'][3:]
eef_ang = quaternion_to_rotation_matrix(eef_quat)
eef_ang = rotation_matrix_to_ortho6d(eef_ang)
eef_vel = state['end_effector_cartesian_velocity'][:3]
# We do not use angular velocity since it is very inaccurate in this environment
# eef_angular_vel = state['end_effector_cartesian_velocity'][3:]
# Concatenate the state
state['arm_concat'] = tf.concat([joint_pos, eef_pos, eef_ang, eef_vel], axis=0)
# Write the state format
state['format'] = tf.constant(
"arm_joint_0_pos,arm_joint_1_pos,arm_joint_2_pos,arm_joint_3_pos,arm_joint_4_pos,arm_joint_5_pos,gripper_joint_0_pos,gripper_joint_1_pos,eef_pos_x,eef_pos_y,eef_pos_z,eef_angle_0,eef_angle_1,eef_angle_2,eef_angle_3,eef_angle_4,eef_angle_5,eef_vel_x,eef_vel_y,eef_vel_z")
process_step函数核心的是最下面的那两行,arm_concat就是把需要的字段组合在一起,format就指示了包含哪些字段,以逗号分隔,每个字符串都是128维统一空间的一个key,例如arm_joint_0就代表第一个关节的角度。每个数据集都提供这个函数,然后外围框架就用这些信息把它统一到128维空间中。
- 训练过程中的action真值从哪里来
大家看到上面代码中的action[‘format’]可能会觉得是这个字段,但实际上其实并没有使用这个字段,可以参考flatten_episode函数中的以下2行代码,action其实是使用了将来的state作为action。作者解释:for each robot, we can use a single space to accommodate both its proprioception Z t Z_{t} Zt and action a t a_{t} at. This is because a t a_{t} at is usually a subset of the desired Z t Z_{t} Zt 。另一方面(个人见解),不同机械臂的控制信号差异性会比较大,例如有的是用end effector坐标xyz控制,有的是用joint角度控制,有的是用xyz或joint角度的增量控制,有的还需要输入速度/角速度控制信号,可能会更加复杂一些。
'action_chunk': future_states[i],
'action_chunk_time_mask': future_states_time_mask[i],
finetune
finetune一般是少量数据在上面pretrain的模型基础之上提升某个特定场景或机械臂的效果。上面pretrain使用的数据格式都是rlds格式,存储的文件是以record形式存储,用google的tfds库来读取,因为Open X-Embodiment用的是这种格式,所以作者也将其它非这种格式的转换成了这种格式。但在finetune中,作者使用了另外一种数据格式叫hdf5,下面这篇文件中讲到了这些数据格式,可以参考:
在finetune的过程中,就没有使用pretrain所使用的producer&consumer模式了,而是直接从hdf5中加载数据,我想是因为数据量小的原因吧。另外,在这里使用了数据集中的action原始数据,这点与pretrain也不太一样,相当于给大家finetune时更多的修改空间。但大家若finetune需要注意action空间要跟state对齐。
因为我测试的时候,手头没有可用的hdf5数据,所以我写了个脚本,把常见的tlds格式转换为hdft,放在这里可供参考。
import cv2
import h5py
import numpy as np
import tensorflow_datasets as tfds
from PIL import Image
import os
display_key = 'image'
datasets_name = "cmu_stretch"
b = tfds.builder_from_directory(f"/home/ubuntu/Downloads/openvla/austin_buds_dataset_converted_externally_to_rlds/0.1.0")
ds = b.as_dataset(split='train') # 具体可以根据需求改
output_dir = f'/home/ubuntu/Downloads/rdt/austin_buds_dataset_converted_externally_to_rlds_hdf5/'
os.makedirs(output_dir, exist_ok=True)
def images_encoding(imgs):
encode_data = []
padded_data = []
max_len = 0
for i in range(len(imgs)):
success, encoded_image = cv2.imencode('.png', imgs[i].numpy())
jpeg_data = encoded_image.tobytes()
encode_data.append(jpeg_data)
# encode_data.append(np.frombuffer(jpeg_data, dtype='S1'))
max_len = max(max_len, len(jpeg_data))
# padding
for i in range(len(imgs)):
padded_data.append(encode_data[i].ljust(max_len, b'\0'))
return encode_data, max_len
instructions_file_path = os.path.join(output_dir, f'instruction.txt')
state_file_path = os.path.join(output_dir, f'state.txt')
# 遍历数据集
for idx, episode in enumerate(ds):
# 为每个视频创建一个文件夹
#video_folder = os.path.join(output_dir, f'video_{idx}')
#os.makedirs(video_folder, exist_ok=True)
# 提取该视频的所有帧
frames = episode['steps']
# 遍历每一帧并保存
state_list = []
# 存储hdf5要使用的数据
qpos = []
actions = []
cam_high = []
cam_right_wrist = []
past_state = np.zeros(7)
for frame_idx, step in enumerate(frames):
state = step['observation']["state"]
state = np.array(state) # x,y,z,rx,ry,rz
# state = np.append(state, data["gripper"]) # 添加一个张开度0~1
state = state.astype(np.float32)
pos = state[:6]
pos = np.append(pos, state[7])
qpos.append(pos)
# 每个数据集image的特征名字不一样,具体要看数据集下载好后的 features.json 文件中对应的字段是什么
image = step['observation']['image'] # fractal20220817_data
wrist_image = step['observation']['wrist_image'] # fractal20220817_data
# image = step['observation']["agentview_rgb"] # viola
# image = step['observation']["image"] # bridge
# 获取自然语言指令,具体要看数据集下载好后的 features.json 文件对应的字段是什么
natural_language_instruction = step["language_instruction"].numpy().decode('utf-8') # for ucsd、berkeley_fanuc_manipulation
#natural_language_instruction = step['observation']["natural_language_instruction"].numpy().decode('utf-8')
#state_list.append(step['observation']["state"])
# 将图像转换为 PIL 格式
#image_pil = Image.fromarray(image.numpy())
# 保存图像,文件名格式为 frame_{frame_idx}.png
#output_path = os.path.join(video_folder, f'frame_{frame_idx}.png')
#image_pil.save(output_path)
if frame_idx == 0:
pass
elif frame_idx == len(frames) - 1:
action = state - past_state
action_new = action[:6]
action_new = np.append(action_new, action[7])
actions.append(action_new)
actions.append(action_new) # 最后一次轨迹没有预测,就用最后一次的轨迹本身作为预测
else:
action = state - past_state
action_new = action[:6]
action_new = np.append(action_new, action[7])
actions.append(action_new)
cam_high.append(wrist_image)
cam_right_wrist.append(image)
past_state = state
hdf5path = os.path.join(output_dir, f'episode_{idx}.hdf5')
with h5py.File(hdf5path, 'w') as f:
f.create_dataset('action', data=np.array(actions))
obs = f.create_group('observations')
image = obs.create_group('images')
obs.create_dataset('qpos', data=qpos)
# 图像编码后按顺序存储
cam_high_enc, len_high = images_encoding(cam_high)
cam_right_wrist_enc, len_right = images_encoding(cam_right_wrist)
image.create_dataset('cam_high', data=cam_high_enc, dtype=f'S{len_high}')
image.create_dataset('cam_right_wrist', data=cam_right_wrist_enc, dtype=f'S{len_right}')
with open(instructions_file_path, 'a') as f:
f.write(f"{natural_language_instruction}\n")