在pytorch的分布式训练中,每个卡都会有一个模型(replicate步骤),以及分配的输入(scatter步骤),最后再把每个模型的输出合并(gather步骤),如果每个模型输出的维度不一致的话,是无法gather的。
因此,查看模型return的值,确实是在根据场景实时变化的。其会根据各个样本中具体场景而发生变化,而不同的卡上输出tensor维度不一样,所以无法gather。
报错虽然出现在底层,但是问题本身还是在于模型。在改掉变化的部分之后能够正常运行。
在pytorch的分布式训练中,每个卡都会有一个模型(replicate步骤),以及分配的输入(scatter步骤),最后再把每个模型的输出合并(gather步骤),如果每个模型输出的维度不一致的话,是无法gather的。
因此,查看模型return的值,确实是在根据场景实时变化的。其会根据各个样本中具体场景而发生变化,而不同的卡上输出tensor维度不一样,所以无法gather。
报错虽然出现在底层,但是问题本身还是在于模型。在改掉变化的部分之后能够正常运行。