今天pytorch官网更新了pytorch2.0稳定版,迫不及待的我直接更新了,确实像官方所说,只需加入model=torch.compile(model)一行代码即可加速,加入的位置如下。
cpu训练:
model=UNet(deep_supervision=True)
model=torch.compile(model)
单卡训练:
model=UNet(deep_supervision=True)
model.to(Device)
model=torch.compile(model)
多卡训练:
model=UNet(deep_supervision=True)
model.to(Device)
model=nn.parallel.DistributedDataParallel(
model,
device_ids=[local_rank],
output_device=local_rank,
broadcast_buffers=False,
)
model=torch.compile(model)
注意 model = torch.compile(model) 这句话的位置对了就可以了,其他的不用改!!