我的错误出现在
h = torch.sum(h, dim=[2, 3])
而h的类型为
tensor([], device='xla:1', dtype=torch.float16, grad_fn=<ReluBackward0>)
因为xla在TPU上运算需要使用 bfloat16
所以更改h的类型为bfloat16
h = h.type(torch.bfloat16)
h = torch.sum(h, dim=[2, 3])
我的错误出现在
h = torch.sum(h, dim=[2, 3])
而h的类型为
tensor([], device='xla:1', dtype=torch.float16, grad_fn=<ReluBackward0>)
因为xla在TPU上运算需要使用 bfloat16
所以更改h的类型为bfloat16
h = h.type(torch.bfloat16)
h = torch.sum(h, dim=[2, 3])