social-gan可视化
social-gan的源码并不提供可视化代码,需要可视化可以查看以下大佬文章。下面的文章很细,同时提供github源码。但是有一个地方可能有一点小问题。我进行了一些改进。
Social GAN——可视化
注意
上面的文章实现了4条轨迹同时绘制,但是这4条轨迹并不是同一空间同一时刻下的四条轨迹。所以可视化的结果并不具备实际意义。
主要是文章中代码的取数据方式存在错误:
源代码如下:
for batch in loader:
nn+=1
batch = [tensor.cuda() for tensor in batch]
(obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
non_linear_ped, loss_mask, seq_start_end) = batch
for _ in range(1):#num_samples
pred_traj_fake_rel = generator(
obs_traj, obs_traj_rel, seq_start_end
)
pred_traj_fake = relative_to_abs(
pred_traj_fake_rel, obs_traj[-1]
)
gt=